In [2]:
import logging
import os
import sys
import importlib
import argparse
import munch
import yaml
from utils.vis_utils import plot_single_pcd
from utils.train_utils import *
from dataset import ShapeNetH5
config_path = 'cfgs/vrcnet.yaml'
args = munch.munchify(yaml.safe_load(open(config_path)))
exp_name = os.path.basename(args.load_model)
log_dir = os.path.dirname(args.load_model)
logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(log_dir, 'rot_test.log')),
                                                  logging.StreamHandler(sys.stdout)])

In [16]:
dataset_test = ShapeNetH5(train=False, npoints=args.num_points)
dataloader_test = torch.utils.data.DataLoader(dataset_test, batch_size=args.batch_size,
                                              shuffle=False, num_workers=int(args.workers))
dataset_length = len(dataset_test)
logging.info('Length of rotated test dataset:%d', len(dataset_test))

(41600, 2048, 3)
(1600, 2048, 3)
(41600,)
INFO:root:Length of rotated test dataset:41600


In [4]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

In [5]:
model_module = importlib.import_module('.%s' % args.model_name, 'models')
net = torch.nn.DataParallel(model_module.Model(args))
net.cuda()
net.module.load_state_dict(torch.load(args.load_model)['net_state_dict'])
logging.info("%s's previous weights loaded." % args.model_name)

Loaded compiled 3D CUDA chamfer distance
INFO:root:vrcnet's previous weights loaded.


In [17]:
# azimuthal_angle = 90
import numpy as np
# angle coordinates are askewed
rotation_matrix = np.array(((0,0,1),(0,1,0),(-1,0,0)))
dataset_test.gt_data = dataset_test.gt_data @ rotation_matrix
print(dataset_test.gt_data.shape)

(1600, 2048, 3)


In [18]:
dataset_test.input_data = dataset_test.input_data @ rotation_matrix
print(dataset_test.input_data.shape)

(41600, 2048, 3)


In [36]:
metrics = ['cd_p', 'cd_t', 'f1']
test_loss_meters = {m: AverageValueMeter() for m in metrics}
test_loss_cat = torch.zeros([16, 3], dtype=torch.float32).cuda()
cat_num = torch.ones([16, 1], dtype=torch.float32).cuda() * 150
cat_name = ['airplane', 'cabinet', 'car', 'chair', 'lamp', 'sofa', 'table', 'vessel',
            'bed', 'bench', 'bookshelf', 'bus', 'guitar', 'motorbike', 'pistol', 'skateboard']
idx_to_plot = [i for i in range(0, 41600, 75)]
logging.info('Testing...')
if args.save_vis:
    save_gt_path = os.path.join(log_dir, 'rot_pics', 'gt')
    save_partial_path = os.path.join(log_dir, 'rot_pics', 'partial')
    save_completion_path = os.path.join(log_dir, 'rot_pics', 'completion')
    os.makedirs(save_gt_path, exist_ok=True)
    os.makedirs(save_partial_path, exist_ok=True)
    os.makedirs(save_completion_path, exist_ok=True)

INFO:root:Testing...


In [None]:
with torch.no_grad():
    for i, data in enumerate(dataloader_test):
            
        label, inputs_cpu, gt_cpu = data
        # mean_feature = None

        inputs = inputs_cpu.float().cuda()
        gt = gt_cpu.float().cuda()
        inputs = inputs.transpose(2, 1).contiguous()
        # result_dict = net(inputs, gt, is_training=False, mean_feature=mean_feature)
        result_dict = net(inputs, gt, is_training=False)
        for k, v in test_loss_meters.items():
            v.update(result_dict[k].mean().item())

        for j, l in enumerate(label):
            for ind, m in enumerate(metrics):
                test_loss_cat[int(l), ind] = result_dict[m][int(j)]

        if i % args.step_interval_to_print == 0:
            logging.info('test [%d/%d]' % (i, dataset_length / args.batch_size))

        if args.save_vis:
            for j in range(args.batch_size):
                idx = i * args.batch_size + j
                if idx in idx_to_plot:
                    pic = 'object_%d.png' % idx
                    plot_single_pcd(result_dict['out2'][j].cpu().numpy(), os.path.join(save_completion_path, pic))
                    plot_single_pcd(gt_cpu[j], os.path.join(save_gt_path, pic))
                    plot_single_pcd(inputs_cpu[j].cpu().numpy(), os.path.join(save_partial_path, pic))

INFO:root:test [0/1300]


In [35]:
with torch.no_grad():
    logging.info('Loss per category:')
    category_log = ''
    for i in range(16):
        category_log += '\ncategory name: %s ' % (cat_name[i])
        for ind, m in enumerate(metrics):
            scale_factor = 1 if m == 'f1' else 10000
            category_log += '%s: %f ' % (m, test_loss_cat[i, 0] / cat_num[i] * scale_factor)
    logging.info(category_log)

    logging.info('Overview results:')
    overview_log = ''
    for metric, meter in test_loss_meters.items():
        overview_log += '%s: %f ' % (metric, meter.avg)
    logging.info(overview_log)

INFO:root:Loss per category:
INFO:root:
category name: airplane cd_p: 0.962176 cd_t: 0.962176 f1: 0.000096 
category name: cabinet cd_p: 1.181573 cd_t: 1.181573 f1: 0.000118 
category name: car cd_p: 1.461717 cd_t: 1.461717 f1: 0.000146 
category name: chair cd_p: 1.879329 cd_t: 1.879329 f1: 0.000188 
category name: lamp cd_p: 0.679914 cd_t: 0.679914 f1: 0.000068 
category name: sofa cd_p: 1.779218 cd_t: 1.779218 f1: 0.000178 
category name: table cd_p: 0.672744 cd_t: 0.672744 f1: 0.000067 
category name: vessel cd_p: 1.142574 cd_t: 1.142574 f1: 0.000114 
category name: bed cd_p: 1.338193 cd_t: 1.338193 f1: 0.000134 
category name: bench cd_p: 0.933539 cd_t: 0.933539 f1: 0.000093 
category name: bookshelf cd_p: 0.914230 cd_t: 0.914230 f1: 0.000091 
category name: bus cd_p: 0.785461 cd_t: 0.785461 f1: 0.000079 
category name: guitar cd_p: 0.751776 cd_t: 0.751776 f1: 0.000075 
category name: motorbike cd_p: 0.785553 cd_t: 0.785553 f1: 0.000079 
category name: pistol cd_p: 0.939464 cd_t: 