In [1]:
"""
    a test script for box-shape free generation
"""

import os
import sys
import shutil
import numpy as np
import torch
import utils
import vis_utils_layout as vis_utils
from data_layout import LayoutDataset, Tree
import model_layout as model

# Cell purpose: load the training config and the pretrained decoder from the
# experiment directory, prepare a fresh output directory, and leave `decoder`
# on the target device in eval mode for the generation cell below.

sys.setrecursionlimit(5000) # this code uses recursion a lot for code simplicity

# number of shapes to generate
num_gen = 300

# experiment directory; the code below reads conf.pth and ckpts/ from it and
# writes results to its generation/ subfolder.
# NOTE(review): machine-specific absolute path -- adjust before running elsewhere.
path = '/home/weiran/Projects/RvNN-Layout/GT-Layout/magazine/logs/magazine_2.5K_test_emb'
# prefix prepended to 'net_decoder.pth' when building the checkpoint filename;
# the empty string selects the un-prefixed file
checkpoint = ''

# set up device: prefer the first GPU, but fall back to the CPU so the
# notebook still runs on a machine without CUDA
device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# load train config; map_location keeps any tensor stored inside the file
# loadable even when it was saved from a GPU that is not present here
conf = torch.load(path + '/conf.pth', map_location=device)
conf.device = device_name
print(f'Using device: {conf.device}')

# load object category information
Tree.load_category_info(conf.category)

# check if eval results already exist. If so, delete it.
# NOTE: the directory is removed unconditionally, without asking for confirmation.
out_dir = path + '/generation'
if os.path.exists(out_dir):
    shutil.rmtree(out_dir)

# create a new directory to store eval results
os.mkdir(out_dir)

# create models
decoder = model.RecursiveDecoder(conf)

# load the pretrained models
# NOTE(review): the message below says 'pretrained_decoder.pth' but the file
# actually read is ckpts/<checkpoint>net_decoder.pth; the text is kept as-is so
# the recorded output stays valid.
print('Loading ckpt pretrained_decoder.pth')
data_to_restore = torch.load(path + '/ckpts/' + checkpoint + 'net_decoder.pth',
                             map_location=device)
decoder.load_state_dict(data_to_restore, strict=True)
print('DONE\n')

# send to device
decoder.to(device)

# set models to evaluation mode (as the last expression of the cell, this also
# displays the module repr)
decoder.eval()

Using device: cuda:0
Loading ckpt pretrained_decoder.pth
DONE



RecursiveDecoder(
  (box_decoder): BoxDecoder(
    (xy): Linear(in_features=256, out_features=2, bias=True)
    (size): Linear(in_features=256, out_features=2, bias=True)
  )
  (sem_decoder): SemDecoder(
    (decoder): Linear(in_features=256, out_features=10, bias=True)
  )
  (leaf_decoder): LeafDecoder(
    (decoder): Linear(in_features=256, out_features=256, bias=True)
  )
  (vertical_decoder): BranchDecoder(
    (mlp_parent_1): Linear(in_features=256, out_features=1280, bias=True)
    (mlp_parent_2): Linear(in_features=1280, out_features=1280, bias=True)
    (mlp_exists): Linear(in_features=256, out_features=1, bias=True)
    (mlp_arrange): Linear(in_features=256, out_features=20, bias=True)
    (mlp_sem): Linear(in_features=256, out_features=10, bias=True)
    (mlp_child): Linear(in_features=256, out_features=256, bias=True)
  )
  (horizontal_decoder): BranchDecoder(
    (mlp_parent_1): Linear(in_features=256, out_features=1280, bias=True)
    (mlp_parent_2): Linear(in_features=128

In [2]:
# generate shapes: sample `num_gen` latent codes, decode each into a layout
# hierarchy, and write a text dump plus a rendered image per sample to `out_dir`
with torch.no_grad():
    for i in range(num_gen):
        print(f'Generating {i}/{num_gen} ...')

        # get a Gaussian noise; move it to the same `device` the decoder was
        # sent to instead of whatever the default CUDA device happens to be
        code = torch.randn(1, conf.feature_size).to(device)

        # infer through the model to get the generated hierarchy
        # set maximal tree depth to conf.max_tree_depth
        obj_arr = decoder.decode_structure(z=code, max_depth=conf.max_tree_depth)

        # return value is discarded, so this is called for its effect on obj_arr
        # NOTE(review): presumably resolves box arrangement in place -- confirm in data_layout
        obj_arr.get_arrbox()

        # output the hierarchy
        with open(os.path.join(out_dir, 'gen-%03d.txt'%i), 'w') as fout:
            fout.write(str(obj_arr)+'\n\n')

        # output the assembled box-shape
        vis_utils.draw_partnet_objects(
            [obj_arr],
            object_names=['GENERATION'],
            figsize=(5, 6),
            out_fn=os.path.join(out_dir, 'gen-%03d.jpg'%i),
            leafs_only=True,
            sem_colors_filename='./part_colors_magazine.txt',
        )


Generating 0/300 ...
Generating 1/300 ...
Generating 2/300 ...
Generating 3/300 ...
Generating 4/300 ...
Generating 5/300 ...
Generating 6/300 ...
Generating 7/300 ...
Generating 8/300 ...
Generating 9/300 ...
Generating 10/300 ...
Generating 11/300 ...
Generating 12/300 ...
Generating 13/300 ...
Generating 14/300 ...
Generating 15/300 ...
Generating 16/300 ...
Generating 17/300 ...
Generating 18/300 ...
Generating 19/300 ...
Generating 20/300 ...
Generating 21/300 ...
Generating 22/300 ...
Generating 23/300 ...
Generating 24/300 ...
Generating 25/300 ...
Generating 26/300 ...
Generating 27/300 ...
Generating 28/300 ...
Generating 29/300 ...
Generating 30/300 ...
Generating 31/300 ...
Generating 32/300 ...
Generating 33/300 ...
Generating 34/300 ...
Generating 35/300 ...
Generating 36/300 ...
Generating 37/300 ...
Generating 38/300 ...
Generating 39/300 ...
Generating 40/300 ...
Generating 41/300 ...
Generating 42/300 ...
Generating 43/300 ...
Generating 44/300 ...
Generating 45/300 ..