In [1]:
# Re-import edited local modules (models, render_utils, vis_utils, refine_utils) automatically.
%load_ext autoreload
%autoreload 2

In [2]:
import pyrender
import os
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from torch.autograd import Variable
from torchsummary import summary
import skimage.measure as sk

import h5py

import time
import pymesh
import trimesh

#%matplotlib notebook

In [3]:
device = torch.device("cpu")

# reproducible: seed every RNG the notebook draws from -- torch for the networks and
# numpy for the random choice of extra views in the depth refinement step.
np.random.seed(0)
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
torch.manual_seed(0)

<torch._C.Generator at 0x7f4ebfa956b0>

In [4]:
# Display the device that the networks and tensors are moved to.
device

device(type='cpu')

In [5]:
class ChairDepthDataset(Dataset):
    """Chair depth renderings stored in an HDF5 file, one top-level group per sample.

    Each group is read for the datasets ``depth_img``, ``azimuth``, ``elevation``,
    ``distance`` and ``target_vox``. The file stays open for the lifetime of the
    dataset; call ``close()`` (or use the dataset in a ``with`` block) to release it.
    """

    def __init__(self, h5_file):
        self.hf = h5py.File(h5_file, 'r')
        # group names double as sample ids; their order defines the sample index
        self.keys = list(self.hf.keys())

        self.to_tensor = transforms.ToTensor()

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """Return a dict with the model id, depth image tensor, camera pose and target voxels."""
        key = self.keys[idx]
        group = self.hf[key]

        depth_img = self.to_tensor(Image.fromarray(np.array(group['depth_img'])))
        # `[()]` reads the whole HDF5 dataset into a numpy array in one call; handing the
        # h5py Dataset object itself to torch.tensor can make torch index it piecewise
        # as a Python sequence.
        target_vox = torch.tensor(group['target_vox'][()], dtype=torch.float)

        return {'model_id': str(key),
                'depth_img': depth_img,
                'azimuth': float(group['azimuth'][()]),
                'elevation': float(group['elevation'][()]),
                'distance': float(group['distance'][()]),
                'target_vox': target_vox,
                }

    def close(self):
        """Close the underlying HDF5 file; calling it again is a no-op."""
        hf = getattr(self, 'hf', None)
        if hf is not None:
            hf.close()
            self.hf = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

In [6]:
# Preprocessed test split of the chair renderings. Set the CHAIR_DEPTH_H5 environment
# variable to point at another copy instead of editing this absolute path.
data_file = os.environ.get(
    'CHAIR_DEPTH_H5',
    '/home/ankbzpx/datasets/ShapeNet/ShapeNetRenderingh5_v1/03001627/data_test_rescale.h5',
)

test_depth_dataset = ChairDepthDataset(data_file)

In [7]:
# plot vox or img
from vis_utils import plotFromVoxels, plotImg, plot_image_list
# generate with continuous model
from render_utils import generate_mesh, get_relative_transform_matrix
# render mesh with pyrender
from render_utils import render
# mesh from ground truth vox
from render_utils import RotateAlongAxis
# transfrom sdf
from render_utils import get_transformed_indices, sdf2Voxel, get_meshgrid, get_transformed_meshgrid, get_relative_transformed_vox
# get cd, emd, iou from 2 pymesh
from render_utils import get_test_results

In [8]:
# Pre-trained Model path
discrete_encoder_path = 'discrete_encoder.pth'
mapping_path = 'mapping.pth'
discrete_decoder_path = 'discrete_decoder.pth'
unet_path = 'con_unet_full.pth'
continuous_model_path = 'continuous_model.pth'


from models import Discrete_encoder, Mapping, Discrete_decoder, Conditional_UNET, Continuous


def _freeze(model):
    """Switch ``model`` to eval mode and disable gradients for every parameter it owns.

    Iterates ``model.parameters()`` rather than only the parameters of
    ``model.children()``, so parameters registered directly on ``model`` are frozen too.
    """
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model


def _load_frozen(model, weights_path):
    """Load the state dict at ``weights_path`` into ``model``, then freeze it.

    ``map_location=device`` lets checkpoints that were saved from a GPU load when
    ``device`` is the CPU.
    """
    model.load_state_dict(torch.load(weights_path, map_location=device))
    return _freeze(model)


####################
# Discrete Encoder #
####################

discrete_encoder = _load_frozen(Discrete_encoder().to(device), discrete_encoder_path)

###########
# Mapping #
###########

mapping = _load_frozen(Mapping().to(device), mapping_path)

####################
# Discrete Decoder #
####################

# Frozen like the other networks: the refinement step optimises only the latent code
# fed into the decoder, and gradients still flow back to that input.
discrete_decoder = _load_frozen(Discrete_decoder().to(device), discrete_decoder_path)

########
# UNET #
########

# pre-trained model is loaded within the model
unet = _freeze(Conditional_UNET(unet_path).to(device))

##############
# Continuous #
##############

continuous = _load_frozen(Continuous().to(device), continuous_model_path)

In [9]:
# Edge length of the cubic voxel grid, and the degrees-to-radians conversion factor.
voxsize, D2R = 32, np.pi / 180

In [10]:
# custom depth consistency loss

from torch.autograd import Function
from torch.autograd.function import once_differentiable

class DepthConsistencyLoss(Function):
    """Masked binary cross-entropy between predicted occupancy and depth-derived masks.

    Voxels flagged in ``depth_index`` are pushed towards occupied (target 1) and voxels
    flagged in ``close_index`` towards empty (target 0); where both masks are set, the
    empty target wins. The summed loss is divided by
    ``sum(depth_index) + sum(close_index)`` (clamped to at least 1, so empty masks give 0).

    Shapes as used in this notebook:
        out_vox:     n x 1 x 32 x 32 x 32 occupancy probabilities
        depth_index: n x 32 x 32 x 32 mask
        close_index: n x 32 x 32 x 32 mask

    NOTE(review): the previous backward ignored ``grad_output`` and the 1/N factor,
    i.e. it returned the gradient of the un-normalised sum. This backward is the true
    derivative of ``forward``'s return value, so learning rates tuned against the old
    gradient produce correspondingly smaller steps now.
    """

    @staticmethod
    def forward(ctx, out_vox, depth_index, close_index):
        vox = out_vox.squeeze(1)

        close_mask = close_index.bool()
        # the close target overrides the depth target on overlapping voxels
        depth_mask = depth_index.bool() & ~close_mask
        denom = (depth_index.sum() + close_index.sum()).to(vox.dtype).clamp_min(1)

        loss = torch.zeros_like(vox)
        # bce form; logs are clamped at -100 (as torch's binary_cross_entropy does) so a
        # saturated probability yields a large finite loss instead of inf
        loss[depth_mask] = -torch.log(vox[depth_mask]).clamp_min(-100)
        loss[close_mask] = -torch.log(1 - vox[close_mask]).clamp_min(-100)

        ctx.save_for_backward(out_vox)
        ctx.depth_mask = depth_mask
        ctx.close_mask = close_mask
        ctx.denom = denom

        return loss.sum() / denom

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        out_vox, = ctx.saved_tensors
        vox = out_vox.squeeze(1)

        grad = torch.zeros_like(vox)
        # d/dv[-log v] = -1/v and d/dv[-log(1 - v)] = 1/(1 - v); 1e-12 guards the division
        grad[ctx.depth_mask] = -1 / vox[ctx.depth_mask].clamp_min(1e-12)
        grad[ctx.close_mask] = 1 / (1 - vox[ctx.close_mask]).clamp_min(1e-12)
        # chain rule: scale by the upstream gradient and the mean's 1/N factor
        grad = grad * (grad_output / ctx.denom)

        # one gradient per forward input; the index masks are not differentiable
        return grad.unsqueeze(1), None, None

In [11]:
from refine_utils import get_radius, get_depth_close_idx

In [12]:
from render_utils import mesh_from_voxel

In [13]:
# Interactive (nbagg) matplotlib backend for the plotting helpers.
%matplotlib notebook

mesh_grid_res = 64    # grid resolution passed to generate_mesh (also used to normalise its meshes)
mesh_batch_size = 32  # batch size passed to generate_mesh

VIEWS_PER_MODEL = 24  # consecutive dataset entries rendered from the same chair


def _load_view(dataset, sample_idx, radius):
    """Read one rendered view and compute its depth/close voxel masks with ``radius``.

    Returns ``(azimuth, elevation, depth_index, close_index)``.
    """
    sample = dataset[sample_idx]
    depth_img = sample['depth_img'].unsqueeze(0)
    depth_index, close_index = get_depth_close_idx(depth_img, sample['distance'], radius)
    return sample['azimuth'], sample['elevation'], depth_index, close_index


def depth_refine_three_imgs(idx, lr = 1e-3, iter_count = 30, views_per_model = VIEWS_PER_MODEL,
                            dataset = None, model_root = None):
    """Refine one chair's latent code using depth consistency with three views.

    The first view of chair ``idx`` is encoded into a latent code ``w``. ``w`` is then
    optimised with SGD so that the decoded voxels agree with the depth masks of that
    reference view (weighted twice) and of two other, distinct, randomly chosen views.

    Args:
        idx: chair index; its views are ``dataset[views_per_model*idx + k]``.
        lr: SGD learning rate for the latent code.
        iter_count: number of optimisation steps.
        views_per_model: number of consecutive dataset entries per chair.
        dataset: dataset to read from; defaults to the global ``test_depth_dataset``.
        model_root: ShapeNet directory containing ``<model_id>/model.obj``; defaults to
            the global ``path_model``.

    Returns:
        ``(mesh_gt, mesh_vox_before, mesh_vox_after, mesh_before, mesh_after)``.
    """
    if dataset is None:
        dataset = test_depth_dataset
    if model_root is None:
        model_root = path_model

    first_sample = views_per_model * idx

    ########### reference view / initial depth ##########

    data_from = dataset[first_sample]
    depth_img_from = data_from['depth_img'].unsqueeze(0)
    distance_from = data_from['distance']
    azimuth_from = data_from['azimuth']
    elevation_from = data_from['elevation']

    radius = get_radius(depth_img_from, distance_from)
    depth_index_from, close_index_from = get_depth_close_idx(depth_img_from, distance_from, radius)

    ########### gt mesh ##########

    # vertices transformed by the reference azimuth/elevation, then scaled into the unit ball
    model_file = os.path.join(model_root, data_from['model_id'].split('_')[0], 'model.obj')
    mesh_py = pymesh.load_mesh(model_file)
    transformed_vertices = get_transformed_indices(mesh_py.vertices, azimuth_from, elevation_from, 1)
    gt_radius = np.max(np.linalg.norm(transformed_vertices, axis = 1))
    mesh_gt = pymesh.form_mesh(transformed_vertices/gt_radius, mesh_py.faces)

    ########### initial prediction ##########

    z = discrete_encoder(depth_img_from.to(device))
    # detached copy: only the latent code is optimised
    w = mapping(z).clone().detach().requires_grad_(True)
    optimizer = torch.optim.SGD([w], lr=lr, momentum=0.8)

    out_vox = torch.sigmoid(discrete_decoder(w))

    mesh_vox_before = mesh_from_voxel(out_vox[0, 0])
    mesh_before = generate_mesh(continuous, unet, out_vox, z, device, vox_res = 32, grid_res = mesh_grid_res,
                                batch_size = mesh_batch_size, azimuth = 0, elevation = 0, isosurface = 0.0)

    ########### two additional views ##########

    # replace=False: with replacement the same extra view could be drawn twice
    extra_offsets = np.random.choice(np.arange(1, views_per_model), 2, replace=False)
    other_views = []
    for offset in extra_offsets:
        azimuth_to, elevation_to, depth_index_to, close_index_to = _load_view(
            dataset, first_sample + int(offset), radius)
        other_views.append((azimuth_to, elevation_to,
                            depth_index_to.unsqueeze(0).to(device), close_index_to.unsqueeze(0).to(device)))

    # loop-invariant masks of the reference view
    depth_index_from = depth_index_from.unsqueeze(0).to(device)
    close_index_from = close_index_from.unsqueeze(0).to(device)

    # Optimize with depth consistency loss
    for _ in range(iter_count):

        optimizer.zero_grad()

        loss = 2 * DepthConsistencyLoss.apply(out_vox, depth_index_from, close_index_from)
        for azimuth_to, elevation_to, depth_index_to, close_index_to in other_views:
            transformed_vox = get_relative_transformed_vox(out_vox, -azimuth_from, -elevation_from,
                                                           azimuth_to, elevation_to,
                                                           device, voxsize = 32, align_mode = 'zeros')
            loss = loss + DepthConsistencyLoss.apply(transformed_vox, depth_index_to, close_index_to)

        # out_vox is rebuilt from w after every step, so the graph need not be retained
        loss.backward()
        optimizer.step()

        out_vox = torch.sigmoid(discrete_decoder(w))

    mesh_after = generate_mesh(continuous, unet, out_vox, z, device, vox_res = 32, grid_res = mesh_grid_res,
                               batch_size = mesh_batch_size, azimuth = 0, elevation = 0, isosurface = 0.0)

    mesh_vox_after = mesh_from_voxel(out_vox[0, 0])

    return mesh_gt, mesh_vox_before, mesh_vox_after, mesh_before, mesh_after

In [14]:
def normalize_mesh(mesh_py, vox_size = 64):
    """Centre a mesh given in voxel coordinates and scale it into the unit ball.

    Vertices are shifted by ``vox_size / 2`` (the grid centre, assuming coordinates span
    ``[0, vox_size]``) and divided by the largest resulting vertex norm, so the farthest
    vertex lands on the unit sphere. Faces are reused unchanged.
    """
    half_extent = vox_size / 2
    centred = mesh_py.vertices - half_extent
    farthest = np.linalg.norm(centred, axis = 1).max()
    return pymesh.form_mesh(centred / farthest, mesh_py.faces)

In [15]:
path_model = '/home/ankbzpx/datasets/ShapeNet/ShapeNetCore.v1/03001627/'

# Evaluate at most this many chairs (None = every chair in the test split).
# A full pass is slow, so a small limit is handy for quick checks.
EVAL_MODEL_LIMIT = None

# Metric lists named <metric>_gt_<v|m>_<b|a>_list:
#   v = mesh extracted from the decoded voxels, m = mesh from the continuous model,
#   b / a = before / after depth-consistency refinement.
cd_gt_v_b_list, emd_gt_v_b_list, iou_gt_v_b_list = [], [], []
cd_gt_v_a_list, emd_gt_v_a_list, iou_gt_v_a_list = [], [], []
cd_gt_m_b_list, emd_gt_m_b_list, iou_gt_m_b_list = [], [], []
cd_gt_m_a_list, emd_gt_m_a_list, iou_gt_m_a_list = [], [], []

failed_idx = []

# every chair contributes 24 consecutive renderings to the dataset
n_models = len(test_depth_dataset) // 24
if EVAL_MODEL_LIMIT is not None:
    n_models = min(n_models, EVAL_MODEL_LIMIT)

for idx in range(n_models):
    print(idx)
    start_time = time.time()

    mesh_gt, mesh_vox_before, mesh_vox_after, mesh_before, mesh_after = depth_refine_three_imgs(idx)

    if any(mesh is None for mesh in (mesh_gt, mesh_vox_before, mesh_vox_after, mesh_before, mesh_after)):
        print('Failed case')
        failed_idx.append(idx)
        continue

    # mesh_gt is already scaled into the unit ball by depth_refine_three_imgs
    mesh_vox_before = normalize_mesh(mesh_vox_before, 64)
    mesh_vox_after = normalize_mesh(mesh_vox_after, 64)
    mesh_before = normalize_mesh(mesh_before, mesh_grid_res)
    mesh_after = normalize_mesh(mesh_after, mesh_grid_res)

    for mesh, (cd_list, emd_list, iou_list) in (
            (mesh_vox_before, (cd_gt_v_b_list, emd_gt_v_b_list, iou_gt_v_b_list)),
            (mesh_vox_after, (cd_gt_v_a_list, emd_gt_v_a_list, iou_gt_v_a_list)),
            (mesh_before, (cd_gt_m_b_list, emd_gt_m_b_list, iou_gt_m_b_list)),
            (mesh_after, (cd_gt_m_a_list, emd_gt_m_a_list, iou_gt_m_a_list))):
        cd, emd, iou = get_test_results(mesh_gt, mesh)
        cd_list.append(cd)
        emd_list.append(emd)
        iou_list.append(iou)

    print("--- %s seconds ---" % (time.time() - start_time))

0
--- 25.606603145599365 seconds ---
1
--- 29.813942193984985 seconds ---
2
--- 34.64919328689575 seconds ---
3
--- 35.95049452781677 seconds ---
4
--- 33.853222370147705 seconds ---
5
--- 34.43232274055481 seconds ---
6
--- 34.434258222579956 seconds ---
7
--- 35.18905544281006 seconds ---
8
--- 34.008018493652344 seconds ---
9
--- 34.977938413619995 seconds ---
10
--- 37.29167103767395 seconds ---
11
--- 35.17013669013977 seconds ---
12
--- 35.857972383499146 seconds ---
13
--- 34.93836808204651 seconds ---
14
--- 34.80594062805176 seconds ---
15
--- 34.51568865776062 seconds ---
16
--- 36.01037240028381 seconds ---
17
--- 37.68583154678345 seconds ---
18
--- 35.930830240249634 seconds ---
19
--- 35.56439423561096 seconds ---
20
--- 37.737003564834595 seconds ---
21
--- 36.48070287704468 seconds ---
22
--- 36.148614168167114 seconds ---
23
--- 36.701833963394165 seconds ---
24
--- 36.406747579574585 seconds ---
25
--- 36.154290437698364 seconds ---
26
--- 36.70730900764465 seconds --

KeyboardInterrupt: 

In [None]:
def print_metric_summary(title, cd_list, emd_list, iou_list):
    """Print the mean (m_) and standard deviation (s_) of each metric list under a header."""
    print('--- %s ---' % title)
    for metric_name, values in (('Chamfer Distance', cd_list),
                                ('Earth Movers Distance', emd_list),
                                ('Intersection over Union', iou_list)):
        print(metric_name + ': m_', np.mean(values), ' s_', np.std(values))
        print()


print_metric_summary('vox before', cd_gt_v_b_list, emd_gt_v_b_list, iou_gt_v_b_list)
print_metric_summary('vox after', cd_gt_v_a_list, emd_gt_v_a_list, iou_gt_v_a_list)
print_metric_summary('mesh before', cd_gt_m_b_list, emd_gt_m_b_list, iou_gt_m_b_list)
print_metric_summary('mesh after', cd_gt_m_a_list, emd_gt_m_a_list, iou_gt_m_a_list)

In [None]:
# Render the ground-truth mesh left over from the last iteration of the evaluation loop.
render(mesh_gt)

In [34]:
# Render the voxel-derived mesh before refinement (last evaluated chair).
render(mesh_vox_before)

In [35]:
# Render the voxel-derived mesh after refinement (last evaluated chair).
render(mesh_vox_after)

In [36]:
# Render the continuous-model mesh before refinement (last evaluated chair).
render(mesh_before)

In [37]:
# Render the continuous-model mesh after refinement (last evaluated chair).
render(mesh_after)

In [55]:
# Number of chairs in the test split: each chair has 24 consecutive renderings.
len(test_depth_dataset) // 24

1311