In [None]:
import argparse
import json
import logging
import os
import random
import time

import torch
import numpy as np

In [None]:
import sys 
sys.path.append('..')
import utils.misc as ws
import utils.data_utils
import utils.train_utils
import utils.eval_utils
import utils.mesh
import utils.dataset as d
import models.networks as arch


In [None]:
# Machine-specific locations can be overridden through environment variables
# so the notebook can run on another machine without editing this cell. When
# the variables are unset, the defaults below are used.
DATA_SOURCE = os.environ.get(
    'GRASP_DATA_SOURCE',
    '/home/ninad/Desktop/Docs/phd-res/proj-irvl-grasp-transfer/code/docker-data/output_dataset/',
)
# DATA_SOURCE = '/home/ninad/Desktop/multi-finger-grasping/output_dataset/'
EXPERIMENTS_DIR = os.environ.get(
    'GRASP_EXPERIMENTS_DIR',
    '../experiments/all3_gemb_varcmap_n200_run3_increased_batch_size_80/',
)
# Name of the saved checkpoint (without the ".pth" suffix) to load.
CHECKPOINT = 'latest'
split_filename = os.path.join(EXPERIMENTS_DIR, 'split_train.json')
specs_filename = os.path.join(EXPERIMENTS_DIR, "specs.json")

LATENT_CODE_DIR = ws.latent_codes_subdir

In [None]:
# Read the experiment specification. The context manager closes the file
# handle deterministically instead of leaving it to garbage collection.
with open(specs_filename) as specs_file:
    specs = json.load(specs_file)

latent_size = specs["CodeLength"]
gripper_weight = specs["GripperWeight"]
num_grippers = specs["NumGrippers"]
# Falls back to 30 when the specs have no "GripperEmbeddingLength" entry.
grp_embedding_size = specs.get("GripperEmbeddingLength", 30)

In [None]:
# Build the decoder. Its input width is the gripper embedding length plus the
# per-instance latent code length.
decoder = arch.dsdfDecoder(
    grp_embedding_size + latent_size,
    **specs["NetworkSpecs"]
    ).cuda()

# Wrap in DataParallel before loading the weights.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model,
# so its state-dict keys carry the wrapper's prefix -- confirm against the
# training script.
decoder = torch.nn.DataParallel(decoder)

saved_model_state = torch.load(
    os.path.join(
        EXPERIMENTS_DIR, ws.model_params_subdir, CHECKPOINT + ".pth")
)

saved_model_epoch = saved_model_state["epoch"]

decoder.load_state_dict(saved_model_state["model_state_dict"])

# Unwrap the underlying module for single-device use.
decoder = decoder.module.cuda()

# This notebook only runs inference: switch train-time layers (e.g. dropout,
# batch norm) to evaluation behaviour so repeated queries are deterministic.
decoder.eval()

In [None]:
# Load the train split that selects which instances this experiment uses.
with open(split_filename, "r") as split_file:
    split = json.load(split_file)

# npz_filenames = utils.data_utils.dsdf_get_instance_filenames(
#     args.data_source, split)
instance_filelist = utils.data_utils.get_instance_filelist(DATA_SOURCE, split)
cmap_f, grp_names, gpc_f, npz_filenames = instance_filelist

# random.shuffle(npz_filenames) # WHY??? DISABLE THIS FOR CHECKING REPRODUCIBILITY

In [None]:
# Sanity check: print the last 35 characters of entries 1..9 of the file list.
for npz_path in npz_filenames[1:10]:
    print(npz_path[-35:])

In [None]:
# Load the latent vectors and gripper vectors stored for CHECKPOINT, then
# report their shapes.
latent_vecs = ws.load_latent_vectors(EXPERIMENTS_DIR, CHECKPOINT)
gripper_vecs = ws.load_gripper_vectors(EXPERIMENTS_DIR, CHECKPOINT)
for loaded_vecs in (latent_vecs, gripper_vecs):
    print(loaded_vecs.shape)

In [None]:
# index_to_select_1 = random.randint(0, len(npz_filenames)-1)
# index_to_select_2 = random.randint(0, len(npz_filenames)-1)

# Fixed pair of instance indices whose latent codes are interpolated later.
index_to_select_1, index_to_select_2 = 31, 500

print(index_to_select_1, index_to_select_2)

npz_1, npz_2 = npz_filenames[index_to_select_1], npz_filenames[index_to_select_2]

full_filename = npz_1

# Show each chosen index next to the tail of its file path.
for chosen_index, chosen_npz in ((index_to_select_1, npz_1), (index_to_select_2, npz_2)):
    print(chosen_index, chosen_npz[-35:])


In [None]:
# Rows of `gripper_vecs` paired with the first and second instance.
# Both are named here so the selection is changed in one place.
gripper_idx = 2
gripper_idx_2 = 1

# unsqueeze(1) appends a trailing axis; the next cell concatenates the gripper
# and instance vectors along dim 0.
latent_vec_1 = latent_vecs[index_to_select_1].unsqueeze(1)
latent_vec_2 = latent_vecs[index_to_select_2].unsqueeze(1)

grp_vec_1 = gripper_vecs[gripper_idx].unsqueeze(1)
grp_vec_2 = gripper_vecs[gripper_idx_2].unsqueeze(1)


In [None]:
# Stack each gripper vector in front of its instance latent vector (dim 0).
combined_latent_code_1 = torch.cat((grp_vec_1, latent_vec_1), dim=0)
combined_latent_code_2 = torch.cat((grp_vec_2, latent_vec_2), dim=0)
print(combined_latent_code_1.shape)
# print(combined_latent_code)

In [None]:
# Drop all size-1 axes from the combined codes.
combined_latent_code_1 = torch.squeeze(combined_latent_code_1)
combined_latent_code_2 = torch.squeeze(combined_latent_code_2)
print(combined_latent_code_1.shape)

In [None]:
# ####### MESH RECONSTRUCTION CODE!

# is_gripper = True

# if is_gripper:
#     mesh_filename = os.path.join(EXPERIMENTS_DIR, f'test_{npz[-15:]}')
# else:
#     mesh_filename = os.path.join(EXPERIMENTS_DIR, f'test_o_{npz[-15:]}')


# latent_vec = latent_vec.squeeze().cuda()
# with torch.no_grad():
#     utils.mesh.create_mesh_custom(
#         decoder, latent_vec, mesh_filename, N=256, max_batch=int(2 ** 18), isGripper=is_gripper)


In [None]:
# IMPORTANT: keep shuffle=False so that the same data point is loaded as
# index_to_select.

# sdf_dataset = d.SDFSamples(DATA_SOURCE, split, 1000000)

# NOTE(review): the third argument is presumably the number of samples drawn
# per instance -- confirm against utils.dataset.
sdf_dataset = d.MultiGripperSamples(DATA_SOURCE, split, 100000)

loader_options = dict(batch_size=1, shuffle=False, num_workers=8, drop_last=True)
sdf_loader = torch.utils.data.DataLoader(sdf_dataset, **loader_options)

In [None]:
# Fetch the two selected items straight from the dataset. Each item unpacks
# into five fields; the second field is not used here.
_gidx_1, _, samples_1, idx_1, npzfile_1 = sdf_dataset[index_to_select_1]
# The second item's first field gets its own name so it does not overwrite
# the value unpacked from the first item.
_gidx_2, _, samples_2, idx_2, npzfile_2 = sdf_dataset[index_to_select_2]
print(npzfile_1[-40:], npzfile_2[-40:])

# queries = samples[:, :3] # Need to pass this through the network
# gt_sdf_obj = samples[:, 3].squeeze().numpy()
# gt_sdf_grp = samples[:, 4].squeeze().numpy()
# print(gt_sdf_grp.shape)

In [None]:
# print(samples.shape)

In [None]:
def get_cvx_combination(l1, l2, alpha):
    """Return the weighted blend ``alpha * l1 + (1 - alpha) * l2``.

    With ``alpha == 1`` the result is ``l1``; with ``alpha == 0`` it is ``l2``.
    Works for any operands that support ``*`` and ``+`` with each other.
    """
    weighted_first = l1 * alpha
    weighted_second = (1 - alpha) * l2
    return weighted_first + weighted_second

In [None]:
# Eleven evenly spaced blend coefficients covering [0, 1] in steps of 0.1.
alpha_list = np.linspace(start=0, stop=1, num=11)

In [None]:
# with torch.no_grad():
#     queries, sdf_obj, sdf_grp = utils.eval_utils.eval_query_pc(decoder, latent_vec.cuda(), queries)

In [None]:
# idx = 4
# cvx_code = get_cvx_combination(latent_vec_1, latent_vec_2, alpha_list[idx])
# cvx_code = cvx_code.cuda()

In [None]:
# queries, sdf_obj, sdf_grp = utils.eval_utils.eval_random_query_pc(decoder, cvx_code, num_samples=1000000)

In [None]:
# One blended code per coefficient. With get_cvx_combination's weighting,
# alpha=0 gives combined_latent_code_2 and alpha=1 gives combined_latent_code_1.
latent_codes = [
    get_cvx_combination(combined_latent_code_1, combined_latent_code_2, alpha)
    for alpha in alpha_list
]

# Inference only: disable autograd so no graph is kept alive for the
# num_samples query points of every interpolation step.
# NOTE(review): assumes eval_random_query_pc does not itself need gradients --
# confirm against utils.eval_utils.
with torch.no_grad():
    results = [
        utils.eval_utils.eval_random_query_pc(decoder, code.cuda(), num_samples=100000)
        for code in latent_codes
    ]



In [None]:
# Pick one interpolation step (position 7 in alpha_list) and move its three
# outputs to NumPy arrays on the host.
idx = 7

queries, sdf_obj, sdf_grp = (
    output.detach().cpu().numpy() for output in results[idx]
)

In [None]:
# Report the shapes of the gripper SDF, object SDF and query arrays, in that order.
for inspected in (sdf_grp, sdf_obj, queries):
    print(inspected.shape)

In [None]:
# NOTE: plot_sdf is defined in a later cell of this notebook; that cell has to
# be executed before this one.
# Apply the same mask to the points and to the SDF values: plot_sdf indexes
# both arrays with the same positions, so they must have matching rows for
# each colour to belong to the point it is drawn on.
obj_mask = sdf_obj < 1e-4
plot_sdf(queries[obj_mask], sdf_obj[obj_mask])

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline

def _new_sdf_axes():
    """Create a 10x10 inch figure and return its single 3D axes."""
    fig = plt.figure(figsize=(10, 10))
    return fig.add_subplot(111, projection='3d')


def _scatter_sdf_sample(ax, xyz, sdf, n_display):
    """Scatter a random subsample of ``xyz`` on ``ax``, coloured by ``sdf``.

    ``n_display`` row indices are drawn with replacement, so a point may be
    drawn more than once. Column 2 of ``xyz`` goes on the plot's second axis
    and column 1 on its third. ``sdf`` must be indexable by the same row
    positions as ``xyz``. Nothing is drawn when ``xyz`` has no rows, because
    ``np.random.choice`` raises on an empty population.
    """
    n_points = xyz.shape[0]
    if n_points == 0:
        return
    ind = np.random.choice(n_points, n_display)
    data = xyz[ind].T
    ax.scatter(data[0], data[2], data[1], s=5, c=sdf[ind])


def _finalize_sdf_figure(ax, title, fname):
    """Apply the shared camera, limits and title, then save or show the figure."""
    ax.view_init(20, 100)
    limit = (-0.95, 0.95)
    ax.set_xlim3d(*limit)
    ax.set_ylim3d(*limit)
    ax.set_zlim3d(*limit)
    plt.title(title)
    # A truthy fname writes the figure to that path; otherwise it is shown.
    if fname:
        plt.savefig(fname)
    else:
        plt.show()


def plot_sdf_2(xyz, sdf, xyz2, sdf2, title='Sample_Title', n_display=10000, fname=None):
    """Plot two SDF-coloured point clouds in the same 3D axes.

    Each cloud is subsampled independently with its own random indices.

    Args:
        xyz, xyz2: arrays of points, one row per point, three columns.
        sdf, sdf2: per-point values used as colours for ``xyz`` / ``xyz2``.
        title: figure title.
        n_display: number of indices drawn (with replacement) per cloud.
        fname: if truthy, save the figure to this path instead of showing it.
    """
    ax = _new_sdf_axes()
    _scatter_sdf_sample(ax, xyz, sdf, n_display)
    _scatter_sdf_sample(ax, xyz2, sdf2, n_display)
    _finalize_sdf_figure(ax, title, fname)


def plot_sdf(xyz, sdf, title='Sample_Title', n_display=10000, fname=None):
    """Plot one SDF-coloured point cloud in 3D.

    Args:
        xyz: array of points, one row per point, three columns.
        sdf: per-point values used as colours; rows must line up with ``xyz``.
        title: figure title.
        n_display: number of indices drawn (with replacement) for display.
        fname: if truthy, save the figure to this path instead of showing it.
    """
    ax = _new_sdf_axes()
    _scatter_sdf_sample(ax, xyz, sdf, n_display)
    _finalize_sdf_figure(ax, title, fname)


def plot_sdf_with_box(xyz, sdf, box_pts, title='Sample_Title', n_display=10000, fname=None):
    """Plot an SDF-coloured point cloud plus a set of reference points in red.

    Args:
        xyz: array of points, one row per point, three columns.
        sdf: per-point values used as colours; rows must line up with ``xyz``.
        box_pts: reference points laid out with one coordinate per row
            (``box_pts[0]``, ``box_pts[1]``, ``box_pts[2]``); all are drawn.
        title: figure title.
        n_display: number of indices drawn (with replacement) from ``xyz``.
        fname: if truthy, save the figure to this path instead of showing it.
    """
    ax = _new_sdf_axes()
    _scatter_sdf_sample(ax, xyz, sdf, n_display)
    # Same axis ordering as the main cloud: row 2 on the second plot axis.
    ax.scatter(box_pts[0], box_pts[2], box_pts[1], s=2, c='red')
    _finalize_sdf_figure(ax, title, fname)

In [None]:
# Reference points drawn in red on every frame: up to 1000 draws (with
# replacement) from the query points whose object SDF is below 1e-3.
# NOTE: `queries` / `sdf_obj` here are whatever the last executed cell left in
# them, and the loop below reassigns both, so re-running this cell after the
# loop picks the reference points from the last interpolation step instead.
box_data = queries[sdf_obj < 1e-3]
indices = np.random.choice(range(box_data.shape[0]), 1000)
box_data = box_data[indices].T

# Saving a figure does not create missing folders, so make sure the output
# directory exists before the first frame is written.
viz_dir = os.path.join(EXPERIMENTS_DIR, 'viz')
os.makedirs(viz_dir, exist_ok=True)

# Points whose gripper SDF is at or below this threshold are plotted.
EPS = -1e-4

for i, alpha in enumerate(alpha_list):
    queries, sdf_obj, sdf_grp = results[i]
    queries = queries.detach().cpu().numpy()
    sdf_obj = sdf_obj.detach().cpu().numpy()
    sdf_grp = sdf_grp.detach().cpu().numpy()
    ind_grp = sdf_grp <= EPS

    fname_save = os.path.join(viz_dir,
                              f'interpolation_{index_to_select_1}_{index_to_select_2}_{alpha:.1f}.png')

    # Points are filtered by the gripper SDF but coloured by the object SDF.
    plot_sdf_with_box(queries[ind_grp], sdf_obj[ind_grp], box_data,
             title=f'Grasps={index_to_select_1, index_to_select_2} ; Coef={alpha:.1f}',
             fname=fname_save)

    
    
# EPS = -1e-4
# ind_grp = sdf_grp <= EPS
# plot_sdf(queries[ind_grp], sdf_obj[ind_grp])

In [None]:
import pyrender
# Just pass the points you want to visualize
def _view_point_cloud(pts, colors):
    """Build a one-mesh pyrender scene from ``pts``/``colors`` and open a viewer on it."""
    cloud = pyrender.Mesh.from_points(pts, colors=colors)
    scene = pyrender.Scene()
    scene.add(cloud)
    # NOTE(review): the viewer presumably blocks until its window is closed --
    # confirm against the pyrender documentation.
    pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2)


def _threshold_colors(pts, sdf, threshold):
    """Return per-point colours: channel 1 set below ``threshold``, channel 0 above.

    Points whose value equals ``threshold`` keep the all-zero colour.
    """
    colors = np.zeros(pts.shape)
    colors[sdf < threshold, 1] = 1
    colors[sdf > threshold, 0] = 1
    return colors


def plt_points_3d(pts):
    """Show ``pts`` in a pyrender viewer with an all-zero colour for every point."""
    _view_point_cloud(pts, np.zeros(pts.shape))


def plt_points_sdf(pts, sdf, eps=1e-4):
    """Show ``pts`` coloured by which side of ``eps`` their ``sdf`` value falls on.

    Values below ``eps`` set colour channel 1, values above set channel 0.
    """
    _view_point_cloud(pts, _threshold_colors(pts, sdf, eps))


def plt_points_sdf_compare(pts, sdf_gt, sdf_pred):
    """Show ``pts`` twice, one viewer after the other, coloured by sign.

    The first viewer colours by ``sdf_gt``, the second by ``sdf_pred``; in
    both, negative values set colour channel 1 and positive values channel 0.
    """
    _view_point_cloud(pts, _threshold_colors(pts, sdf_gt, 0))
    _view_point_cloud(pts, _threshold_colors(pts, sdf_pred, 0))

In [None]:
# View only the query points whose gripper SDF is below 1e-4.
grp_mask = sdf_grp < 1e-4
plt_points_3d(queries[grp_mask])

In [None]:
# View only the query points whose object SDF is below 1e-4.
obj_points = queries[sdf_obj < 1e-4]
plt_points_3d(obj_points)

In [None]:
# View all query points, coloured by the gripper SDF relative to the default eps.
plt_points_sdf(pts=queries, sdf=sdf_grp)

In [None]:
# View all query points, coloured by the object SDF relative to the default eps.
plt_points_sdf(pts=queries, sdf=sdf_obj)