In [1]:
import argparse
import json
import logging
import os
import random
import time
import textwrap

import torch
import numpy as np

from scipy.spatial.transform import Rotation
from scipy.spatial.distance import pdist, cdist, squareform

from PIL import Image
from PIL.Image import Image as PilImage
import matplotlib.pyplot as plt

In [2]:
import sys 

def add_path(path):
    """Put ``path`` at the front of ``sys.path`` unless it is already listed."""
    if path in sys.path:
        # Already importable from here; leave the search order untouched.
        return
    sys.path.insert(0, path)

# Make the parent of the working directory importable; the `utils.*` imports
# that follow are presumably resolved from there.
THIS_DIR = os.path.dirname('./')  # evaluates to '.', i.e. the current working directory
LIB_PATH = os.path.join(THIS_DIR, '..')  # one level above the working directory
add_path(LIB_PATH)  # prepended to sys.path (no-op if already present)

import utils.misc as ws
import utils.data_utils
import utils.train_utils
import utils.eval_utils
import utils.mesh
import utils.dataset


In [3]:
def get_actual_idx(training_contact_maps, global_idx):
    """Split the path stored at ``global_idx`` into a folder name and a trailing id.

    Args:
        training_contact_maps: indexable collection of "/"-separated path strings.
        global_idx: position of the path to decode.

    Returns:
        A 2-tuple of strings: the second-to-last "/"-separated component of the
        path, and the last "_"-separated token of the final component, taken
        from the part of that component that precedes its first ".".
    """
    components = training_contact_maps[global_idx].split("/")
    parent_name = components[-2]
    # Keep only what comes before the first "." of the file name ...
    stem, _, _ = components[-1].partition(".")
    # ... and from that, only what follows the last "_".
    _, _, trailing_id = stem.rpartition("_")
    return (parent_name, trailing_id)

In [4]:
# Dataset roots, both resolved relative to LIB_PATH.
dataset_dir = os.path.join(LIB_PATH, '../dataset_train/')
validation_dir = os.path.join(LIB_PATH, '../dataset_validation/')

# Object whose images / contact maps are inspected; switch by (un)commenting.
# object_model = "003_cracker_box_google_16k_textured_scale_1000"
object_model = "005_tomato_soup_can_google_16k_textured_scale_1000"

# Experiment directory holding the split file and the specs file read below.
# NOTE(review): this path is relative to the working directory, whereas the
# dataset paths above go through LIB_PATH — confirm both resolve as intended.
EXPERIMENTS_DIR = '../experiments/all5_005_dsdf_50_varcmap'

# Per-object sub-folders of the training dataset.
images_dir = os.path.join(dataset_dir, object_model, "images")
contactmap_dir = os.path.join(dataset_dir, object_model, "contactmap")

# Checkpoint identifier handed to the latent-vector loader.
CHECKPOINT = 'latest'
split_file = os.path.join(EXPERIMENTS_DIR, 'split_train.json')
specs_filename = os.path.join(EXPERIMENTS_DIR, "specs.json")

# Training split description, parsed from JSON and passed to the dataset class.
with open(split_file, 'r') as f:
    data_split = json.load(f)


# NOTE(review): not referenced by any cell visible in this notebook.
LATENT_CODE_DIR = ws.latent_codes_subdir

In [5]:
# This has all the necessary information
# Build the dataset object from the training directory and the loaded split.
# The meaning of `subsample=16000` is defined by MultiGripperSamples (not shown here).
sdf_dataset = utils.dataset.MultiGripperSamples(dataset_dir, data_split, subsample=16000)
# Contact-map distances exposed by the dataset, and their complement as a similarity.
# NOTE(review): `1 - dist` presumes the distances lie in [0, 1] — confirm.
trn_cmap_dist = sdf_dataset.cmap_dist
trn_cmap_sim = 1 - trn_cmap_dist

gripper idxs: {'Allegro': 0, 'Barrett': 1, 'HumanHand': 2, 'fetch_gripper': 3, 'panda_grasp': 4}


In [6]:
# Load the latent vectors and construct an (N, N) distance matrix between them

# Read the specs inside a context manager so the file handle is closed
# deterministically instead of being left open until garbage collection.
with open(specs_filename, 'r') as f:
    specs = json.load(f)
latent_size = specs["CodeLength"]
latent_vecs = ws.load_latent_vectors(EXPERIMENTS_DIR, CHECKPOINT)
print(latent_size, latent_vecs.shape)

# Move the codes to host memory and convert to NumPy before handing them to SciPy.
lv_numpy = latent_vecs.cpu().numpy()

# Pairwise cityblock (L1) distances, expanded from condensed to square form,
# then scaled in place so that the largest distance equals 1.
lv_dist = squareform(pdist(lv_numpy, metric='cityblock'))
lv_dist /= np.max(lv_dist)
print(lv_dist.shape)

128 torch.Size([250, 128])
(250, 250)


In [24]:
def check_matches(cmap_dist, lv_dist, N, extreme_end=5):
    """Compare neighbour rankings induced by two distance matrices.

    For each query ``q`` in ``range(N)`` the other samples are ranked by
    ascending distance, once with ``cmap_dist[q]`` and once with ``lv_dist[q]``.
    The query is removed from both rankings by index rather than by assuming
    it sits at rank 0: with tied zero distances (e.g. duplicate samples) the
    default ``np.argsort`` gives no guarantee that the query comes first, and
    it could otherwise be counted as its own nearest neighbour.

    Args:
        cmap_dist: square distance matrix (reference ranking); row ``q`` holds
            the distances from sample ``q`` to every sample.
        lv_dist: square distance matrix (ranking under test), same layout.
        N: number of queries to evaluate (rows ``0 .. N-1``); must be positive.
        extreme_end: size K of the nearest / farthest neighbour sets.

    Returns:
        A 4-tuple ``(setdiff_near, setdiff_far, top1_rate, topK_rate)``:
        the mean size of the symmetric difference between the two K-nearest
        sets, the same for the two K-farthest sets, the fraction of queries
        whose nearest neighbours agree, and the fraction of queries whose
        reference nearest neighbour appears among the K nearest under test.
    """
    setdiff_near = 0.0
    setdiff_far = 0.0

    count_top1 = 0
    count_topK = 0

    for q in range(N):
        cmap_rank = np.argsort(cmap_dist[q])
        lv_rank = np.argsort(lv_dist[q])

        # Drop the query itself explicitly; its rank is not reliable under ties.
        cmap_rank = cmap_rank[cmap_rank != q]
        lv_rank = lv_rank[lv_rank != q]

        near_cmap, near_lv = cmap_rank[:extreme_end], lv_rank[:extreme_end]
        far_cmap, far_lv = cmap_rank[-extreme_end:], lv_rank[-extreme_end:]

        # Count the elements that appear in only one of the two sets.
        setdiff_near += len(set(near_cmap) ^ set(near_lv))
        setdiff_far += len(set(far_cmap) ^ set(far_lv))

        # top-1: both rankings agree on the single nearest neighbour.
        if cmap_rank[0] == lv_rank[0]:
            count_top1 += 1

        # top-K: the reference nearest neighbour is among the K nearest under test.
        if cmap_rank[0] in near_lv:
            count_topK += 1

    setdiff_near /= N
    setdiff_far /= N

    return setdiff_near, setdiff_far, count_top1/N, count_topK/N


def check_avg_sim(cmap_dist, lv_dist, cmap_sim, K=1):
    """Average reference similarity of the neighbours picked by two rankings.

    For every query ``q`` the other samples are ranked by ascending distance
    with ``lv_dist[q]`` and with ``cmap_dist[q]``; the similarity values
    ``cmap_sim[q]`` of the K nearest and K farthest samples under each ranking
    are averaged. The query is removed from both rankings by index rather
    than by skipping rank 0: with tied zero distances the default
    ``np.argsort`` does not guarantee the query comes first, and its
    self-similarity would otherwise leak into the averages.

    Args:
        cmap_dist: square reference distance matrix; its length sets the
            number of queries.
        lv_dist: square distance matrix under test, same layout.
        cmap_sim: square similarity matrix indexed like ``cmap_dist``; rows
            must support NumPy integer-array indexing.
        K: number of nearest / farthest neighbours to average over.

    Returns:
        A 4-tuple of per-query means:
        ``(nearest via lv_dist, nearest via cmap_dist,
        farthest via lv_dist, farthest via cmap_dist)``.
    """
    sim_topK = 0
    sim_cmap = 0
    far_topK = 0
    far_cmap = 0
    N = len(cmap_dist)  # number of queries
    for q in range(N):
        cmap_rank = np.argsort(cmap_dist[q])
        lv_rank = np.argsort(lv_dist[q])

        # Drop the query itself explicitly; its rank is not reliable under ties.
        cmap_rank = cmap_rank[cmap_rank != q]
        lv_rank = lv_rank[lv_rank != q]

        # Nearest
        sim_topK += np.mean(cmap_sim[q][lv_rank[:K]])
        sim_cmap += np.mean(cmap_sim[q][cmap_rank[:K]])
        # Farthest
        far_topK += np.mean(cmap_sim[q][lv_rank[-K:]])
        far_cmap += np.mean(cmap_sim[q][cmap_rank[-K:]])
    return sim_topK/N, sim_cmap/N, far_topK/N, far_cmap/N

In [11]:
# Ranking agreement between contact-map and latent-vector distances,
# reported for several neighbourhood sizes (one printed tuple per size).
for neighbourhood_size in (2, 3, 5, 8):
    print(check_matches(sdf_dataset.cmap_dist, lv_dist, len(sdf_dataset), extreme_end=neighbourhood_size))

(1.304, 1.696, 0.64, 0.832)
(1.88, 2.208, 0.64, 0.904)
(2.792, 3.232, 0.64, 0.932)
(4.176, 4.672, 0.64, 0.968)


In [25]:
# Mean contact-map similarity of the nearest / farthest neighbours,
# reported for several neighbourhood sizes (one labelled line per size).
for top_k in (1, 3, 5):
    print(f"K={top_k}", check_avg_sim(trn_cmap_dist, lv_dist, trn_cmap_sim, top_k))

K=1 (0.7829037137586227, 0.791538979273289, 0.12050250911444393, 0.09844150558390569)
K=3 (0.7529259250979239, 0.7617268475424762, 0.14014646755108362, 0.12377894410689155)
K=5 (0.7361010385875519, 0.7439354169143236, 0.15567472821006892, 0.1418084257351176)


In [None]:
# Index of the sample used as the query in the qualitative comparison below.
query_idx = 188

# Folder name and trailing id decoded from the query's contact-map path;
# the later cell that loads the query image builds its file name from these.
query_gripper, query_gnum = get_actual_idx(sdf_dataset.cmaps, query_idx)

# All sample indices ordered by ascending contact-map distance to the query.
topK_query_cmap = np.argsort(sdf_dataset.cmap_dist[query_idx])

# All sample indices ordered by ascending latent-vector distance to the query.
topK_query_lv = np.argsort(lv_dist[query_idx])


In [None]:
# Print the first and last K entries of both rankings side by side.
# NOTE(review): the [:K] slices start at rank 0, which presumably is the query
# itself, so they cover the query plus K-1 neighbours — confirm this is intended.
K = 15
print("Nearest")
print("GT Cmap:", topK_query_cmap[:K])
print("Pred LV:", topK_query_lv[:K])

print("\nFarthest")
print("GT Cmap:", topK_query_cmap[-K:])
print("Pred LV:", topK_query_lv[-K:])

# True only when both rankings select the same set of indices (order ignored).
print(set(topK_query_cmap[:K]) == set(topK_query_lv[:K]))

In [None]:
# Images

def get_extreme_matches(trn_cmaps, topK_list, k=10):
    """Collect image paths for the ``k`` first and ``k`` last entries of a ranking.

    Args:
        trn_cmaps: indexable collection of contact-map paths, decoded with
            ``get_actual_idx``.
        topK_list: sample indices ordered from nearest to farthest; position 0
            is skipped because it corresponds to the query itself.
        k: number of entries to take from each end of the ranking.

    Returns:
        ``(top_K_close, bot_K_away)``: two lists of ``k`` image paths under the
        module-level ``images_dir``. The first follows ranking positions
        ``1 .. k``; the second follows positions ``-1 .. -k`` (farthest first).
    """
    def _image_path(sample_idx):
        # Path of the image that belongs to one sample of the dataset.
        gripper, gnum = get_actual_idx(trn_cmaps, sample_idx)
        return os.path.join(images_dir, gripper, f"img_graspnum_{gnum}.png")

    offsets = range(1, k + 1)
    top_K_close = [_image_path(topK_list[offset]) for offset in offsets]
    bot_K_away = [_image_path(topK_list[-offset]) for offset in offsets]

    return top_K_close, bot_K_away
    

# Image paths of the 10 nearest / 10 farthest samples under the contact-map ranking.
close_cmap, far_cmap = get_extreme_matches(sdf_dataset.cmaps, topK_query_cmap, k=10)

# Same, under the latent-vector ranking.
close_lv, far_lv = get_extreme_matches(sdf_dataset.cmaps, topK_query_lv, k=10)


# Populate the PIL images
# Note: the opened images are kept in lists and are not explicitly closed in this cell.

## For the cmap (ground truth)
imgs_close_cmap = [Image.open(_img) for _img in close_cmap]
imgs_far_cmap = [Image.open(_img) for _img in far_cmap]

## For the latent vectors
imgs_close_lv = [Image.open(_img) for _img in close_lv]
imgs_far_lv = [Image.open(_img) for _img in far_lv]

In [None]:
# Reference: https://keestalkstech.com/2020/05/plotting-a-grid-of-pil-images-in-jupyter/

def display_images(
    images: list,
    overall_title='Sample title',
    columns=6, width=20, height=8, max_images=15,
    label_wrap_length=50, label_font_size=8):
    """Show a list of images as a grid of subplots in one matplotlib figure.

    Args:
        images: images accepted by ``imshow``. Images that expose a
            ``filename`` attribute are labelled with its base name.
        overall_title: title drawn once for the whole figure.
        columns: number of grid columns.
        width: figure width passed to ``figsize``.
        height: height unit used to derive the figure height from the number
            of full rows.
        max_images: at most this many images are shown; extra ones are dropped
            after printing a notice.
        label_wrap_length: maximum characters per line of a subplot label.
        label_font_size: font size of the subplot labels.

    Returns:
        None. The figure is left as the current pyplot figure.
    """
    if len(images) > max_images:
        print(f"Showing {max_images} images of {len(images)}:")
        images = images[0:max_images]

    n_images = len(images)
    # Ceiling division: exactly as many rows as needed, so a full last row is
    # not followed by an empty one.
    rows = max(1, (n_images + columns - 1) // columns)

    height = 1 + max(height, int(n_images / columns) * height)
    fig = plt.figure(figsize=(width, height))
    # Figure-level title. `plt.title` here would implicitly create an extra
    # full-figure axes underneath the grid instead of titling the figure.
    fig.suptitle(overall_title)
    for i, image in enumerate(images):

        ax = fig.add_subplot(rows, columns, i + 1)
        ax.imshow(image)

        if hasattr(image, 'filename'):
            title = image.filename
            # Strip one trailing slash so basename() does not return ''.
            if title.endswith("/"):
                title = title[0:-1]
            title = os.path.basename(title)
            title = "\n".join(textwrap.wrap(title, label_wrap_length))
            ax.set_title(title, fontsize=label_font_size)

# Function to display the images in a grid

In [None]:
# Load and show the query's image, located under images_dir via the
# query_gripper / query_gnum values decoded from its contact-map path.
query_img_f = os.path.join(images_dir, query_gripper, f"img_graspnum_{query_gnum}.png")
q_img = Image.open(query_img_f)
print("QUERY IMAGE")
plt.imshow(q_img)

In [None]:
# Grid of the images ranked nearest to the query by contact-map distance.
print("CLOSEST according to Contact Map (G.T)")

display_images(imgs_close_cmap, overall_title="Closest Cmap (GT)")

In [None]:
# Grid of the images ranked nearest to the query by latent-vector distance.
print("CLOSEST according to Learned latent vector (Z) space")

display_images(imgs_close_lv, overall_title="Closest L.V")

In [None]:
# Grid of the images ranked farthest from the query by contact-map distance.
print("FARTHEST according to Contact Map (G.T)")

display_images(imgs_far_cmap, overall_title="FARTHEST Cmap (GT)")

In [None]:
# Grid of the images ranked farthest from the query by latent-vector distance.
print("FARTHEST according to Learned latent vector (Z) space")

display_images(imgs_far_lv, overall_title="FARTHEST L.V")