In [2]:
import torch
from utils import *
from collections import defaultdict
import matplotlib.pyplot as plt
import time

from models.rendering import *
from models.nerf import *

import metrics

from datasets import dataset_dict
from datasets.llff import *

# Enable cuDNN benchmark mode (auto-tuned kernel selection).
torch.backends.cudnn.benchmark = True

# Render resolution as (width, height); rendered outputs are reshaped to
# (img_wh[1], img_wh[0], 3), i.e. (height, width, RGB).
img_wh = (200, 200)

# Test split of the mug scene, loaded with the 'blender' dataset class at img_wh.
dataset = dataset_dict['blender']('./data/nerf_synthetic/mug/', 'test', img_wh=img_wh)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [2]:
# Embedders for 3-D inputs, one xyz/dir pair per scene.
# NOTE(review): the second Embedding argument is presumably the number of
# positional-encoding frequencies (10 for xyz, 4 for view direction) — confirm in models.nerf.
hand_embedding_xyz, hand_embedding_dir = Embedding(3, 10), Embedding(3, 4)
object_embedding_xyz, object_embedding_dir = Embedding(3, 10), Embedding(3, 4)

# Separate coarse and fine NeRF networks for the hand and for the object.
hand_nerf_coarse, hand_nerf_fine = NeRF(), NeRF()
object_nerf_coarse, object_nerf_fine = NeRF(), NeRF()

hand_ckpt_path = './ckpts/hand_flat/epoch=7.ckpt'
object_ckpt_path = './ckpts/mug/epoch=7.ckpt'

# Each checkpoint provides both the coarse and the fine weights, selected by model_name.
for _net, _ckpt, _name in (
    (hand_nerf_coarse, hand_ckpt_path, 'nerf_coarse'),
    (hand_nerf_fine, hand_ckpt_path, 'nerf_fine'),
    (object_nerf_coarse, object_ckpt_path, 'nerf_coarse'),
    (object_nerf_fine, object_ckpt_path, 'nerf_fine'),
):
    load_ckpt(_net, _ckpt, model_name=_name)

# Move every network to the GPU and switch it to eval mode for inference.
for _net in (hand_nerf_coarse, hand_nerf_fine, object_nerf_coarse, object_nerf_fine):
    _net.cuda().eval()

# Drop the loop temporaries so they do not linger in the notebook namespace.
del _net, _ckpt, _name

In [3]:
from models.render_blend_mesh import render_rays_blend
# Per-scene [coarse, fine] networks and [xyz, dir] embedders, in the order
# render_rays_blend receives them.
hand_models = [hand_nerf_coarse, hand_nerf_fine]
object_models = [object_nerf_coarse, object_nerf_fine]
hand_embeddings = [hand_embedding_xyz, hand_embedding_dir]
object_embeddings = [object_embedding_xyz, object_embedding_dir]

# Ray-marching settings forwarded to render_rays_blend.
# NOTE(review): presumably N_samples = coarse samples per ray and
# N_importance = extra fine samples per ray — confirm against render_rays_blend.
N_samples = 64
N_importance = 64
use_disp = False
# Rays per render_rays_blend call (4 * 32 * 1024 = 131072).
chunk = 4 * 32 * 1024

@torch.no_grad()
def f_trans(rays, poses = None, mano_layer = None, global_translation = None):
    """Render ``rays`` through the blended hand+object NeRF, ``chunk`` rays at a time.

    Relies on the notebook globals ``hand_models``, ``hand_embeddings``,
    ``object_models``, ``object_embeddings``, ``N_samples``, ``use_disp``,
    ``N_importance``, ``chunk`` and ``dataset``.

    Args:
        rays: ray tensor whose first dimension indexes rays; it is sliced along
            dim 0 in pieces of ``chunk``.
        poses, mano_layer, global_translation: forwarded unchanged to
            ``render_rays_blend``.

    Returns:
        A ``defaultdict(list)`` mapping every output key of ``render_rays_blend``
        to its per-chunk tensors concatenated along dim 0.
    """
    n_rays = rays.shape[0]
    per_key = defaultdict(list)
    for start in range(0, n_rays, chunk):
        ray_batch = rays[start:start + chunk]
        # NOTE(review): the two positional 0s are presumably perturb / noise_std,
        # disabled for deterministic test-time rendering — confirm in render_rays_blend.
        chunk_out = render_rays_blend(
            hand_models, hand_embeddings,
            object_models, object_embeddings,
            ray_batch,
            N_samples, use_disp, 0, 0, N_importance, chunk,
            dataset.white_back,
            test_time=True,
            poses=poses,
            mano_layer=mano_layer,
            global_translation=global_translation,
        )
        for key, value in chunk_out.items():
            per_key[key].append(value)

    # Stitch the chunks back together; update() keeps the defaultdict type of the result.
    per_key.update({key: torch.cat(parts, 0) for key, parts in per_key.items()})
    return per_key

In [4]:
from manopth.manolayer import ManoLayer

import json

# Rays of a single test view (the original note mentions 18 as an alternative index).
sample = dataset[0] # 18
rays = sample['rays'].cuda()

ncomps = 45

with open('./param.json', 'r') as param_file:
    result_dict = json.load(param_file)

# Target hand pose: 'rot' entries followed by 'thetas' entries, as a single batch row.
# NOTE(review): presumably global rotation + per-joint MANO pose — confirm the layout in param.json.
poses_final = torch.tensor([result_dict['rot'] + result_dict['thetas']])
poses_init = torch.zeros_like(poses_final)

# NOTE(review): the factor 12 is an unexplained scale on the stored translation,
# presumably mapping param.json units into the NeRF scene's units — confirm.
global_translation_final = (torch.tensor([result_dict['trans']]) * 12).cuda()
global_translation_init = torch.zeros_like(global_translation_final)

mano_layer = ManoLayer(mano_root='./mano/models', use_pca=False, ncomps=ncomps, flat_hand_mean=True)
# NOTE(review): `shapes` is not referenced by the cells shown here.
shapes = torch.zeros(1, 10)

  torch.Tensor(smpl_data['betas'].r).unsqueeze(0))


In [5]:
import os

import imageio
from tqdm import tqdm

total = 80
grasp_out_dir = './figs/grasp_mug_4'
# imageio.imwrite does not create missing directories, so make sure the target exists.
os.makedirs(grasp_out_dir, exist_ok=True)


def render_and_save_frame(t, frame_idx, out_dir):
    """Render the blended hand+object scene at interpolation factor ``t`` and save it as PNG.

    The hand pose and the global translation are linearly interpolated between the
    ``*_init`` and ``*_final`` tensors: ``t = 0`` gives the initial state, ``t = 1``
    the final one, and ``t > 1`` extrapolates past the final state.

    Args:
        t: interpolation factor (float).
        frame_idx: integer used as the file name, ``<out_dir>/<frame_idx>.png``.
        out_dir: directory the frame is written to.
    """
    poses = poses_init + (poses_final - poses_init) * t
    global_translation = global_translation_init + (global_translation_final - global_translation_init) * t
    results = f_trans(rays, poses, mano_layer, global_translation=global_translation)
    torch.cuda.synchronize()
    img_pred = results['rgb_fine'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
    # Clip before the uint8 cast: casting out-of-range floats to uint8 is not
    # saturating in NumPy, so values outside [0, 1] would otherwise wrap around.
    img_pred_ = (np.clip(img_pred, 0, 1) * 255).astype(np.uint8)
    imageio.imwrite(os.path.join(out_dir, f'{frame_idx}.png'), img_pred_)


# Frames 0..total: interpolate from the initial to the final grasp (t = 0 .. 1).
for frame in tqdm(range(total + 1)):
    render_and_save_frame(frame / total, frame, grasp_out_dir)

# Frames total+1 .. total+40: continue past the final grasp (t > 1).
# NOTE(review): this extrapolates pose/translation up to (total + 40) / total = 1.5x the
# final offset (with total = 80) — confirm the overshoot is intentional rather than
# meant to hold the final pose.
for frame in tqdm(range(total + 1, total + 41)):
    render_and_save_frame(frame / total, frame, grasp_out_dir)

100%|██████████| 81/81 [21:14<00:00, 15.73s/it]
100%|██████████| 40/40 [10:23<00:00, 15.59s/it]


In [5]:
import os

import imageio
from tqdm import tqdm

total = 50
jitter_std = 0.1
jitter_out_dir = './figs/grasp_mug_3'
# imageio.imwrite does not create missing directories, so make sure the target exists.
os.makedirs(jitter_out_dir, exist_ok=True)
# Seed the global torch RNG (CPU and CUDA) so the random perturbations, and therefore
# the saved frames, are reproducible across runs.
torch.manual_seed(0)

# Render `total` frames, each with Gaussian noise (std = jitter_std) added to the
# initial pose and the initial global translation.
for i in tqdm(range(total)):
    poses = poses_init + torch.randn_like(poses_init) * jitter_std
    global_translation = global_translation_init + torch.randn_like(global_translation_init) * jitter_std
    results = f_trans(rays, poses, mano_layer, global_translation=global_translation)
    torch.cuda.synchronize()
    img_pred = results['rgb_fine'].view(img_wh[1], img_wh[0], 3).cpu().numpy()
    # Clip before the uint8 cast so out-of-range values cannot wrap around.
    img_pred_ = (np.clip(img_pred, 0, 1) * 255).astype(np.uint8)
    imageio.imwrite(os.path.join(jitter_out_dir, f'{i}.png'), img_pred_)

100%|██████████| 50/50 [12:20<00:00, 14.81s/it]
