In [2]:
"""
    a test script for box-shape reconstruction
"""

import os
import sys
import shutil
import numpy as np
import torch
import utils
import vis_utils_layout as vis_utils
from data_layout import LayoutDataset, Tree
import model_layout as model
from random import shuffle

sys.setrecursionlimit(5000)  # this code uses recursion a lot for code simplicity

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# experiment directory containing conf.pth and the ckpts/ sub-directory
path = '/home/weiran/Projects/RvNN-Layout/GT-Layout/magazine/logs/magazine_2.5K_test_emb'
# optional filename prefix of the checkpoint pair to load
# ('' loads ckpts/net_encoder.pth and ckpts/net_decoder.pth)
cheakpoint = ''
# dataset directory containing the split file and one <name>.json per layout
data_path = '/home/weiran/Projects/RvNN-Layout/data/magazine-ours/magazine_0417_2.5K/'
# split file listing layout names; NOTE: this reconstructs the TRAIN split
split_file = 'train.txt'
# how many layouts to reconstruct (the first K names of the split file);
# clamped below to the number of names actually available
num_recon = 300
# colour table passed to the renderer
part_colors_file = './part_colors_magazine.txt'

# load train config
conf = torch.load(os.path.join(path, 'conf.pth'), map_location='cpu')

# load object category information
Tree.load_category_info(conf.category)
# override the device stored in the training config
conf.device = 'cuda:1'

# set up device
device = torch.device(conf.device)
print(f'Using device: {conf.device}')

# create models
# probabilistic=False: we do not need to jitter the encoded z during inference,
# which keeps reconstructions deterministic from run to run
encoder = model.RecursiveEncoder(conf, variational=True, probabilistic=False)
decoder = model.RecursiveDecoder(conf)

# load the pretrained weights; the models are still on CPU at this point, so map
# the tensors to CPU (this also works for checkpoints saved on any GPU)
for net, net_name in ((encoder, 'encoder'), (decoder, 'decoder')):
    ckpt_fn = os.path.join(path, 'ckpts', f'{cheakpoint}net_{net_name}.pth')
    print(f'Loading ckpt {ckpt_fn}')
    net.load_state_dict(torch.load(ckpt_fn, map_location='cpu'), strict=True)
    print('DONE\n')

# send to device
encoder.to(device)
decoder.to(device)

# set models to evaluation mode
encoder.eval()
decoder.eval()

# read the list of layout names (one per line; blank lines are skipped so they
# do not turn into bogus '<data_path>.json' lookups)
with open(os.path.join(data_path, split_file), 'r') as fin:
    data_list = [line.strip() for line in fin]
data_list = [name for name in data_list if name]

# never index past the end of the split file
num_recon = min(num_recon, len(data_list))

# start from a fresh output directory. This runs only after the config,
# checkpoints and data list loaded successfully, so a failing run does not
# wipe previous results.
out_dir = os.path.join(path, 'recon')
if os.path.exists(out_dir):
    shutil.rmtree(out_dir)
os.makedirs(out_dir)

# reconstruct layouts
with torch.no_grad():
    for i in range(num_recon):
        name = data_list[i]
        print(f'Reconstructing {i} / {num_recon} ...')

        # load the gt data as the input
        obj = LayoutDataset.load_object(os.path.join(data_path, name + '.json'))
        obj.to(device)

        # feed through the encoder; keep the first conf.feature_size columns
        # as the code z (any extra columns are presumably the KL term —
        # TODO confirm against model_layout)
        root_code_and_kld = encoder.encode_structure(obj)
        root_code = root_code_and_kld[:, :conf.feature_size]

        # infer through the decoder, limiting tree depth to conf.max_tree_depth
        obj_arr = decoder.decode_structure(z=root_code, max_depth=conf.max_tree_depth)
        obj_arr.get_arrbox()

        outputs = ((obj, 'GT'), (obj_arr, 'PRED'))

        # output the hierarchies as text
        for tree, label in outputs:
            with open(os.path.join(out_dir, f'{name}_{label}.txt'), 'w') as fout:
                fout.write(str(tree) + '\n\n')

        # render the leaf boxes of each hierarchy as PNG and SVG
        for tree, label in outputs:
            for ext in ('png', 'svg'):
                vis_utils.draw_partnet_objects(
                    [tree],
                    object_names=[label],
                    leafs_only=True,
                    sem_colors_filename=part_colors_file,
                    figsize=(5, 6),
                    out_fn=os.path.join(out_dir, f'{name}_{label}.{ext}'))

Using device: cuda:1
Loading ckpt pretrained_encoder.pth
DONE

Loading ckpt pretrained_decoder.pth
DONE

Reconstructing 0 / 300 ...
Reconstructing 1 / 300 ...
Reconstructing 2 / 300 ...
Reconstructing 3 / 300 ...
Reconstructing 4 / 300 ...
Reconstructing 5 / 300 ...
Reconstructing 6 / 300 ...
Reconstructing 7 / 300 ...
Reconstructing 8 / 300 ...
Reconstructing 9 / 300 ...
Reconstructing 10 / 300 ...
Reconstructing 11 / 300 ...
Reconstructing 12 / 300 ...
Reconstructing 13 / 300 ...
Reconstructing 14 / 300 ...
Reconstructing 15 / 300 ...
Reconstructing 16 / 300 ...
Reconstructing 17 / 300 ...
Reconstructing 18 / 300 ...
Reconstructing 19 / 300 ...
Reconstructing 20 / 300 ...
Reconstructing 21 / 300 ...
Reconstructing 22 / 300 ...
Reconstructing 23 / 300 ...
Reconstructing 24 / 300 ...
Reconstructing 25 / 300 ...
Reconstructing 26 / 300 ...
Reconstructing 27 / 300 ...
Reconstructing 28 / 300 ...
Reconstructing 29 / 300 ...
Reconstructing 30 / 300 ...
Reconstructing 31 / 300 ...
Reconstru