In [4]:
"""
    a test script for globally interpolating between two layouts in latent space
"""

import os
import sys
import shutil
import numpy as np
import torch
import utils
import vis_utils_layout as vis_utils
from data_layout import LayoutDataset, Tree
import model_layout as model

# The code base relies heavily on recursion, so raise Python's recursion cap.
sys.setrecursionlimit(5000)

# The two shapes used as interpolation endpoints, and how many steps separate them.
shape_id1 = '440_layout'
shape_id2 = '760_layout'
num_interp = 11

# Directory that holds the '<shape_id>.json' files.
root_dir = '../data/magazine_real_1025_1W/'

# Restore the training configuration.
# NOTE(review): torch.load unpickles this file; only load configs you trust.
conf = torch.load('exp_vae/conf.pth')

# Load the object category information named by the config.
Tree.load_category_info(conf.category)

# Select the device named by the config.
device = torch.device(conf.device)
print(f'Using device: {conf.device}')

# Results are written to a directory keyed by the two shape ids.
out_dir = f'exp_vae/globally_interped_shapes_{shape_id1}_{shape_id2}'

# Earlier results at that location are removed without asking for confirmation,
# so every run starts from an empty directory.
if os.path.exists(out_dir):
    shutil.rmtree(out_dir)
os.mkdir(out_dir)

# create models
# we disable probabilistic because we do not need to jitter the decoded z during inference
encoder = model.RecursiveEncoder(conf, variational=True, probabilistic=False)
decoder = model.RecursiveDecoder(conf)

# Restore the weights of both networks from the checkpoints prefixed with '100_'.
for net, ckpt_name in ((encoder, 'net_encoder.pth'), (decoder, 'net_decoder.pth')):
    print('Loading ckpt %s' % ckpt_name)
    # map_location remaps the stored tensors onto the configured device, so a
    # checkpoint saved on a GPU can still be loaded where that GPU is absent.
    data_to_restore = torch.load('./exp_vae/ckpts/100_' + ckpt_name, map_location=device)
    # strict=True: fail loudly if the checkpoint keys do not match the model.
    net.load_state_dict(data_to_restore, strict=True)
    print('DONE\n')

# send to device
encoder.to(device)
decoder.to(device)

# set models to evaluation mode
encoder.eval()
decoder.eval()

Using device: cuda:0
Loading ckpt net_encoder.pth
DONE

Loading ckpt net_decoder.pth
DONE



RecursiveDecoder(
  (box_decoder): BoxDecoder(
    (xy): Linear(in_features=128, out_features=2, bias=True)
    (size): Linear(in_features=128, out_features=2, bias=True)
  )
  (sem_decoder): SemDecoder(
    (decoder): Linear(in_features=128, out_features=7, bias=True)
  )
  (leaf_decoder): LeafDecoder(
    (decoder): Linear(in_features=128, out_features=128, bias=True)
  )
  (vertical_decoder): BranchDecoder(
    (mlp_parent): Linear(in_features=128, out_features=640, bias=True)
    (mlp_exists): Linear(in_features=128, out_features=1, bias=True)
    (mlp_arrange): Linear(in_features=128, out_features=20, bias=True)
    (mlp_sem): Linear(in_features=128, out_features=7, bias=True)
    (mlp_child): Linear(in_features=128, out_features=128, bias=True)
  )
  (horizontal_decoder): BranchDecoder(
    (mlp_parent): Linear(in_features=128, out_features=640, bias=True)
    (mlp_exists): Linear(in_features=128, out_features=1, bias=True)
    (mlp_arrange): Linear(in_features=128, out_features=

In [5]:
# globally interpolate shapes
with torch.no_grad():

    # containers for the interpolated results (arr / rel / abs variants), kept for the visuals
    obj_arr_outs, obj_rel_outs, obj_abs_outs = [], [], []

    # read the two endpoint shapes from disk and move each one to the device
    obj1 = LayoutDataset.load_object(os.path.join(root_dir, f'{shape_id1}.json'))
    obj1.to(device)
    obj2 = LayoutDataset.load_object(os.path.join(root_dir, f'{shape_id2}.json'))
    obj2.to(device)

    # STUDENT CODE START
    # run both shapes through the encoder to obtain the two codes z1 and z2
    z1 = encoder.encode_structure(obj1)
    z2 = encoder.encode_structure(obj2)

In [6]:
# Decode num_interp + 1 blends of z1 and z2 (i = 0, 1, ..., num_interp) so that
# the first code is exactly z1 and the last is exactly z2.

# Start from empty lists so that re-running this cell does not append a second
# batch of results on top of the ones collected by an earlier run.
obj_arr_outs = []
obj_rel_outs = []
obj_abs_outs = []

# Inference only: without no_grad every decoded result would carry an autograd
# graph through the decoder parameters and keep it alive inside the lists.
with torch.no_grad():
    for i in range(num_interp + 1):
        alpha = i / num_interp
        code = (1 - alpha) * z1 + alpha * z2

        # infer through the decoder to get the interpolated output
        # set maximal tree depth to conf.max_tree_depth
        # The same code is decoded three times so that each box conversion below
        # works on its own tree (presumably the get_*box calls modify the tree
        # in place -- confirm against data_layout.Tree).
        obj_arr = decoder.decode_structure(z=code, max_depth=conf.max_tree_depth)
        obj_rel = decoder.decode_structure(z=code, max_depth=conf.max_tree_depth)
        obj_abs = decoder.decode_structure(z=code, max_depth=conf.max_tree_depth)

        obj_arr.get_arrbox()
        obj_rel.get_relbox()
        obj_abs.get_absbox()

        # collect the results for the visualisation below
        obj_arr_outs.append(obj_arr)
        obj_rel_outs.append(obj_rel)
        obj_abs_outs.append(obj_abs)

obj_names = ['interp-%d' % i for i in range(num_interp + 1)]

# output the hierarchy of each interpolation step into its own text file
# (only the i-th result goes into step-i.txt, not the whole list)
for i, obj in enumerate(obj_arr_outs):
    with open(os.path.join(out_dir, 'step-%d.txt' % i), 'w') as fout:
        fout.write(str(obj) + '\n\n')

# output the assembled box-shapes: one figure per result list
for tag, objs in (('arr', obj_arr_outs), ('rel', obj_rel_outs), ('abs', obj_abs_outs)):
    vis_utils.draw_partnet_objects(
        objs, object_names=obj_names,
        out_fn=os.path.join(out_dir, 'interp_figs_%s.png' % tag), figsize=(60, 6),
        leafs_only=True, sem_colors_filename='./part_colors_magazine.txt')