In [1]:
import argparse
import importlib
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded
import os
import pathlib
import random
import sys

import numpy as np
import torch
from torch.multiprocessing import Pool, Process, set_start_method

# Make the repository root importable before pulling in project-local modules.
sys.path.append('../')
from models.network import AutoEncoder
from data.data import ImNetImageSamples, ImNetSamples
from evaluation.eval_utils import sample_points_polygon_vox64_njit
from utils.other_utils import write_ply_point_normal

# Use the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Kind of input fed to the encoder; the dataset cell handles 'image' and 'voxels'.
input_type = 'voxels'

# Test split: HDF5 data file plus the text file passed to the dataset as label_txt_path.
data_path = './../data/all_vox256_img_with_classes/all_vox256_img_test.hdf5'
obj_txt_file = './../data/all_vox256_img_with_classes/all_vox256_img_test.txt'

# Model definition (python config module) and the checkpoint to evaluate.
config_path = './../pretrain/class_pretrain/phase_2_model/config.py'
network_path = './../pretrain/class_pretrain/phase_2_model/model_epoch_2_310.pth'

# Root folder under which per-object outputs are stored.
save_folder = './test_eval'

In [3]:
# Execute the config file at `config_path` and expose it as the module object `config`.
spec = importlib.util.spec_from_file_location('*', config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

## dataload
### create dataset
if input_type == 'image':
    # Use the notebook-level `data_path` / `obj_txt_file`; there is no `args`
    # namespace in this notebook, so `args.*` would raise NameError here.
    samples = ImNetImageSamples(
        data_path=data_path,
        label_txt_path=obj_txt_file,
        image_idx=23,  # Last image, stick to BSP-Net calc
        # Optional config entries: fall back to defaults when the config omits them.
        use_depth=getattr(config, 'use_depth', False),
        image_preferred_color_space=getattr(config, 'image_preferred_color_space', 1),
    )
elif input_type == 'voxels':
    # TODO: sample_voxel_size may need to differ in some cases; promote it to a
    # setting if that is ever required (not needed or used right now).
    samples = ImNetSamples(data_path=data_path, sample_voxel_size=64, label_txt_path=obj_txt_file)
else:
    raise Exception(f'Unknown input type {input_type}. ')

In [4]:
# Show the path identifier stored for the sample at index 1.
samples.obj_paths[1]

'02691156/d18f2aeae4146464bd46d022fd7d80aa'

In [5]:
## loading index
# Only indices divisible by `sample_interval` are kept when the argument list is built.
sample_interval = 1
# Resolution value forwarded to the mesh export call.
resolution = 64
# Threshold forwarded to the export call as `thershold_1` (spelling kept: other cells use this name).
thershold = 0.01
with_surface_point = True # TODO: Is it needed here?
# Image inputs use a smaller batch limit than voxel inputs.
if input_type == 'image':
    max_batch = 20000
else:
    max_batch = 100000

In [6]:
def get_input_data(samples, i, num_input_data_aggregation, aggregate_embedding, view_use_indx_list):
    """Fetch the network input(s) for sample ``i``.

    The input of a sample is read as ``samples[i][0][0]``.

    Args:
        samples: Indexable dataset. When views are iterated, its ``image_idx``
            attribute is written and ``view_num`` may be read.
        i: Index of the sample to fetch.
        num_input_data_aggregation: ``-1`` (together with no explicit view
            list) selects every view in ``range(samples.view_num)``; otherwise
            it is the number of repeated fetches used in the fallback case.
        aggregate_embedding: When falsy, a single input is returned and the
            remaining arguments are ignored.
        view_use_indx_list: Optional explicit list of view indices to fetch.

    Returns:
        A single input when ``aggregate_embedding`` is falsy, otherwise a list
        of inputs (one per selected view, or ``num_input_data_aggregation``
        repeated fetches when no view selection applies).
    """
    # Guard clause: no aggregation requested, hand back the sample's input as is.
    if not aggregate_embedding:
        return samples[i][0][0]

    use_every_view = num_input_data_aggregation == -1 and view_use_indx_list is None
    has_explicit_views = view_use_indx_list is not None and len(view_use_indx_list) > 0

    if use_every_view:
        selected_views = range(samples.view_num)
    elif has_explicit_views:
        selected_views = view_use_indx_list
    else:
        # No view selection: fetch the same sample the requested number of times,
        # without touching samples.image_idx.
        repeated = []
        for _ in range(num_input_data_aggregation):
            repeated.append(samples[i][0][0])
        return repeated

    def _fetch_view(view):
        # Point the dataset at the requested view before reading the sample.
        samples.image_idx = int(view)
        return samples[i][0][0]

    per_view_inputs = [_fetch_view(view) for view in selected_views]
    # Can be left as None afterwards; the previous value (e.g. 23) is not needed here.
    samples.image_idx = None
    return per_view_inputs
# One argument tuple per selected sample among the first ten indices:
# (input data, output folder, resolution, max_batch, space range, thershold, with_surface_point)
generate_args = [
    (
        get_input_data(samples, sample_idx, -1, False, None),
        os.path.join(save_folder, samples.obj_paths[sample_idx]),
        resolution,
        max_batch,
        (-0.5, 0.5),
        thershold,
        with_surface_point,
    )
    for sample_idx in range(10)
    if sample_idx % sample_interval == 0
]

In [7]:
# Preferred GPU index for this run. Fall back to GPU 0 when the machine exposes
# fewer devices, instead of failing in torch.cuda.set_device.
preferred_device_id = 3
device_id = preferred_device_id if torch.cuda.device_count() > preferred_device_id else 0
torch.cuda.set_device(device_id)

# Re-execute the config file so this cell works on a fresh `config` module object.
spec = importlib.util.spec_from_file_location('*', config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)

# map_location places the checkpoint tensors on the GPU selected above; without it
# they are restored to whatever device they were saved from, which may not exist here.
network_state_dict = torch.load(network_path, map_location=f'cuda:{device_id}')
network_state_dict, is_old_style_weights = AutoEncoder.process_state_dict(network_state_dict, type=1)
if is_old_style_weights:
    # Checkpoints flagged as old-style need the config adapted before the model is built.
    config = AutoEncoder.fix_old_weights_config(config)

network = AutoEncoder(config=config).cuda(device_id)
network.load_state_dict(network_state_dict)
# Switch to inference mode; as the last expression this also displays the module structure.
network.eval()

AutoEncoder(
  (encoder): CNN3D(
    (conv_1): Conv3d(1, 48, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (conv_2): Conv3d(48, 96, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (conv_3): Conv3d(96, 192, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (conv_4): Conv3d(192, 384, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (conv_5): Conv3d(384, 384, kernel_size=(4, 4, 4), stride=(1, 1, 1))
  )
  (decoder): FlowDecoder(
    (ode_layers): ModuleList(
      (0): ODEfunc(
        (ode_net): ODE_ResNet(
          (context_linear): Linear(in_features=129, out_features=256, bias=True)
          (coordinate_linear): Linear(in_features=3, out_features=256, bias=True)
          (last_linear): Linear(in_features=256, out_features=3, bias=True)
          (layers): ModuleList(
            (0): Linear(in_features=256, out_features=256, bias=True)
            (1): Linear(in_features=256, out_features=256, bias=True)
          )
   

In [8]:
# Position in `generate_args` of the sample to evaluate.
indx = 1

# Unpack that sample's argument tuple. Note that this rebinds the notebook-level
# resolution, max_batch, thershold and with_surface_point to the tuple's values.
(
    input_data,
    store_file_folder_path,
    resolution,
    max_batch,
    space_range,
    thershold,
    with_surface_point,
) = generate_args[indx]

# File path handed to the export call as `file_path`.
store_file_path = os.path.join(store_file_folder_path, 'obj.ply')
store_file_path

'./test_eval/02691156/d18f2aeae4146464bd46d022fd7d80aa/obj.ply'

In [9]:
# Convert the selected input to a float tensor on the target GPU.
# The conversion is idempotent: on a re-run `input_data` is already a tensor, and
# passing a tensor to torch.from_numpy would raise TypeError.
model_input = input_data[0] if isinstance(input_data, list) else input_data
if not torch.is_tensor(model_input):
    model_input = torch.from_numpy(model_input)
input_data = model_input.float().cuda(device_id)

# Run the export; with return_voxel_and_values=True the call returns a tuple that
# a later cell unpacks (meshes, embedding, convex data, predictions).
result = network.save_bsp_deform(
    inputs=input_data, file_path=store_file_path, resolution=resolution, max_batch=max_batch,
    space_range=space_range, thershold_1=thershold, embedding=None,
    return_voxel_and_values=True
)

In [10]:
from utils.ply_utils import read_ply_point
from evaluation.eval import calculate_cd, calculate_normal_consistency

In [11]:
# Predicted points: read 'obj_deformed.ply' from the sample's output folder
# (presumably written by save_bsp_deform next to obj.ply — TODO confirm).
# The path is built from the folder rather than by substring replacement on the
# full path, which would also rewrite any directory name containing 'obj.ply'.
vertices_pd = read_ply_point(os.path.join(store_file_folder_path, 'obj_deformed.ply'))

# Ground-truth points: rows of data_points[indx] whose first value column exceeds 1e-4.
# NOTE(review): `indx` indexes generate_args; it matches the dataset index only
# while sample_interval == 1.
vertices_gt = samples.data_points[indx][samples.data_values[indx][:, 0] > 1e-4]
vertices_pd.shape, vertices_gt.shape

((23991, 3), (1258, 3))

In [12]:
# Compare the predicted and ground-truth point sets with calculate_cd
# (the name suggests Chamfer distance — TODO confirm against evaluation.eval).
cd_calc = calculate_cd(vertices_pd, vertices_gt)
cd_calc

0.00013360625962377526

In [46]:
# The result tuple carries one extra entry (class logits) when the config enables
# `sample_class`; unpack accordingly.
if getattr(config, 'sample_class', False):
    (
        vertices,
        polygons,
        vertices_deformed,
        polygons_deformed,
        embedding,
        vertices_convex,
        bsp_convex_list,
        predicted_class,
        convex_predictions_sum,
        point_value_prediction,
    ) = result
    # Store the class prediction alongside the exported files.
    np.save(os.path.join(store_file_folder_path, 'predicted_class_logits.npy'), predicted_class)
else:
    (
        vertices,
        polygons,
        vertices_deformed,
        polygons_deformed,
        embedding,
        vertices_convex,
        bsp_convex_list,
        convex_predictions_sum,
        point_value_prediction,
    ) = result

In [48]:
# Compare the array shapes of the original and the deformed vertex sets.
vertices.shape, vertices_deformed.shape

((1425, 3), (1425, 3))