In [1]:
"""
    a test script for box-shape free generation
"""

import os
import sys
import shutil
import numpy as np
import torch
import utils
import vis_utils_layout as vis_utils
from data_layout import LayoutDataset, Tree
import model_layout as model

sys.setrecursionlimit(5000) # this code uses recursion a lot for code simplicity

# number of shapes to generate
num_gen = 1100

path = '/home/weiran/Code/Layout-CVPR22/1111_mag_contain_tanh/model'
data_path = '/home/weiran/Code/Layout-CVPR22/Data/mag_contain_1110_final/'
checkpoint = '307_'

# load train config
conf = torch.load(path + '/conf.pth')

# load object category information
Tree.load_category_info(conf.category)

# set up device
conf.device = 'cuda:0'
device = torch.device(conf.device)
print(f'Using device: {conf.device}')

# check if eval results already exist. If so, delete it. 
out_dir = '/home/weiran/Code/Layout-CVPR22/eval_output/magazine_new_tanh/freely_generated_shapes'
if os.path.exists(out_dir):
    # response = input('result directory %s exists, overwrite? (y/n) ' % out_dir)
    # if response != 'y':
    #     sys.exit()
    shutil.rmtree(out_dir)

# create a new directory to store eval results
os.mkdir(out_dir)

# create models
decoder = model.RecursiveDecoder(conf)

# load the pretrained models
print('Loading ckpt pretrained_decoder.pth')
data_to_restore = torch.load(path + '/' + checkpoint + 'net_decoder.pth')
decoder.load_state_dict(data_to_restore, strict=True)
print('DONE\n')

# send to device
decoder.to(device)

# set models to evaluation mode
decoder.eval()

Using device: cuda:0
Loading ckpt pretrained_decoder.pth
DONE



RecursiveDecoder(
  (box_decoder): BoxDecoder(
    (xy): Linear(in_features=256, out_features=2, bias=True)
    (size): Linear(in_features=256, out_features=2, bias=True)
  )
  (sem_decoder): SemDecoder(
    (decoder): Linear(in_features=256, out_features=8, bias=True)
  )
  (leaf_decoder): LeafDecoder(
    (decoder): Linear(in_features=256, out_features=256, bias=True)
  )
  (vertical_decoder): BranchDecoder(
    (mlp_parent_1): Linear(in_features=256, out_features=1280, bias=True)
    (mlp_parent_2): Linear(in_features=1280, out_features=1280, bias=True)
    (mlp_exists): Linear(in_features=256, out_features=1, bias=True)
    (mlp_arrange): Linear(in_features=256, out_features=20, bias=True)
    (mlp_sem): Linear(in_features=256, out_features=8, bias=True)
    (mlp_child): Linear(in_features=256, out_features=256, bias=True)
  )
  (horizontal_decoder): BranchDecoder(
    (mlp_parent_1): Linear(in_features=256, out_features=1280, bias=True)
    (mlp_parent_2): Linear(in_features=1280,

In [2]:
# generate shapes
with torch.no_grad():
    for i in range(num_gen):
        print(f'Generating {i}/{num_gen} ...')

        # get a Gaussian noise
        code = torch.randn(1, conf.feature_size).cuda()
        
        # infer through the model to get the generated hierarchy
        # set maximal tree depth to conf.max_tree_depth
        obj_arr = decoder.decode_structure(z=code, max_depth=conf.max_tree_depth)

        obj_arr.get_arrbox()
        
        # output the hierarchy
        with open(os.path.join(out_dir, 'gen-%03d.txt'%i), 'w') as fout:
            fout.write(str(obj_arr)+'\n\n')

        # output the assembled box-shape
        vis_utils.draw_partnet_objects([obj_arr],\
                object_names=['GENERATION'], \
                figsize=(5, 6), out_fn=os.path.join(out_dir, 'gen-%03d.jpg'%i),\
                leafs_only=True,sem_colors_filename='./part_colors_magazine.txt')


Generating 0/1100 ...
Generating 1/1100 ...
Generating 2/1100 ...
Generating 3/1100 ...
Generating 4/1100 ...
Generating 5/1100 ...
Generating 6/1100 ...
Generating 7/1100 ...
Generating 8/1100 ...
Generating 9/1100 ...
Generating 10/1100 ...
Generating 11/1100 ...
Generating 12/1100 ...
Generating 13/1100 ...
Generating 14/1100 ...
Generating 15/1100 ...
Generating 16/1100 ...
Generating 17/1100 ...
Generating 18/1100 ...
Generating 19/1100 ...
Generating 20/1100 ...
Generating 21/1100 ...
Generating 22/1100 ...
Generating 23/1100 ...
Generating 24/1100 ...
Generating 25/1100 ...
Generating 26/1100 ...
Generating 27/1100 ...
Generating 28/1100 ...
Generating 29/1100 ...
Generating 30/1100 ...
Generating 31/1100 ...
Generating 32/1100 ...
Generating 33/1100 ...
Generating 34/1100 ...
Generating 35/1100 ...
Generating 36/1100 ...
Generating 37/1100 ...
Generating 38/1100 ...
Generating 39/1100 ...
Generating 40/1100 ...
Generating 41/1100 ...
Generating 42/1100 ...
Generating 43/1100 ..