In [83]:
import json
import os

import h5py
import numpy as np
import torch
from matplotlib import pyplot as plt
%matplotlib widget

from dataset.data_utils import n_parts_map
from mise import MISE
from networks.networks_lgan import Generator, Discriminator
from networks.networks_partae import PartImNetAE
from networks.networks_seq2seq import Seq2SeqAE
from util.visualization import partsdf2mesh, partsdf2voxel

In [60]:
# Load the three pretrained networks (latent GAN, seq2seq autoencoder, part
# implicit autoencoder) from checkpoints under "chkt_dir" and move the pieces
# used below onto the GPU.
# NOTE(review): torch.load unpickles the checkpoint files; only load checkpoints
# from a trusted source.
# NOTE(review): no module is switched to eval() in this cell, so all of them stay
# in whatever mode their constructors leave them in -- confirm this is intended.

# --- latent GAN (generator + critic) ---
gan_checkpoint = torch.load(os.path.join("chkt_dir", "lgan", "ckpt_epoch90000.pth"))
n_dim = 128 # dimension of the generator's input noise vector
h_dim = 2048 # dimension of MLP hidden layer
# dimension of the generated shape code; later cells reshape it with
# .view(2, 1, hidden_size*2), i.e. 1024 = 2 * (2 * hidden_size)
z_dim = 1024
netG = Generator(n_dim, h_dim, z_dim).cuda()
netG.load_state_dict(gan_checkpoint['netG_state_dict'])
netD = Discriminator(h_dim, z_dim).cuda()
netD.load_state_dict(gan_checkpoint['netD_state_dict'])

# --- seq2seq autoencoder: only its decoder is kept ---
max_n_parts = 9
en_z_dim = 128 # part latent dim
hidden_size = 256 # chair latent dim
boxparam_size = 6 # dimension for part box parameters
part_feat_size = en_z_dim + boxparam_size  # per-part feature = latent code followed by box parameters
en_input_size = part_feat_size + n_parts_map(max_n_parts) + 1 # seq2seq input size
de_input_size = part_feat_size
n_layer = 2  # NOTE(review): not passed to Seq2SeqAE below -- presumably matches the model's default; verify
max_length = 10 # max seq length (upper bound on decoding steps in infer_decoder)
netSeq2Seq = Seq2SeqAE(en_input_size, de_input_size, hidden_size)
seq2seq_checkpoint = torch.load(os.path.join("chkt_dir", "seq2seq", "ckpt_epoch2000.pth"))
netSeq2Seq.load_state_dict(seq2seq_checkpoint["model_state_dict"])
netDecoder = netSeq2Seq.decoder.cuda()
del netSeq2Seq  # drop the name of the full autoencoder; netDecoder still references its decoder

# --- part implicit autoencoder: encoder and decoder are both used ---
en_n_layers = 5
en_f_dim = 32
de_n_layers = 6
de_f_dim = 128
imnet_checkpoint = torch.load(os.path.join('chkt_dir', 'partae', 'latest.pth'))
part_imnet = PartImNetAE(en_n_layers, en_f_dim, de_n_layers, de_f_dim, en_z_dim)
part_imnet.load_state_dict(imnet_checkpoint['model_state_dict'])
part_encoder = part_imnet.encoder.cuda()
part_decoder = part_imnet.decoder.cuda()

# Freeze the weights of the networks listed here so that later optimisation only
# updates tensors created with requires_grad=True (part_encoder is not in the list).
nets = [netG, netD, netDecoder, part_decoder]
for n in nets:
    for p in n.parameters():
        p.requires_grad = False

In [44]:
def infer_decoder(decoder, decoder_hidden, length=None):
    """Unroll the sequence decoder step by step and collect one output per part.

    Decoding starts from ``decoder.init_input`` and feeds each step's output
    sequence element back in as the next input. At most ``max_length`` steps
    are run.

    :param decoder: callable returning ``(output, hidden, output_seq, stop_sign)``
                    for ``(input, hidden)``; must expose ``init_input``
    :param decoder_hidden: initial hidden state handed to the first step
    :param length: if given, stop after exactly this many steps (capped by
                   ``max_length``); if None, stop as soon as the sigmoid of the
                   stop sign exceeds 0.5
    :return: dict with "boxparams" (last ``boxparam_size`` channels of the
             stacked outputs) and "vecs" (all remaining channels)
    """
    step_input = decoder.init_input.detach().repeat(1, 1, 1).cuda()
    collected = []
    for step in range(max_length):
        _, decoder_hidden, part_feat, stop_logit = decoder(step_input, decoder_hidden)
        collected.append(part_feat)

        if length is None:
            # free-running mode: the decoder itself decides when to stop
            finished = torch.sigmoid(stop_logit[0, 0]) > 0.5
        else:
            # fixed-length mode: the stop sign is ignored
            finished = step == length - 1
        if finished:
            break

        # autoregressive feedback: the part just produced becomes the next input
        step_input = part_feat.detach().unsqueeze(0)

    stacked = torch.stack(collected, dim=0)
    return {
        "boxparams": stacked[:, :, -boxparam_size:],
        "vecs": stacked[:, :, :-boxparam_size],
    }

In [9]:
def eval_part_sdf(self, part_idx):
    """Sample one part's implicit field on the MISE grid (method-style variant).

    Reads ``vox_dim``, ``upsampling_steps`` and ``threshold`` from ``self`` and
    evaluates points through ``self.eval_part_points``. A plain function with
    the same name is defined later in this notebook and replaces this one once
    that cell has run.

    :param part_idx: int, forwarded to ``self.eval_part_points``
    :return: all_points: (n_points, 3)
             all_values: (n_points, )
    """
    extractor = MISE(self.vox_dim, self.upsampling_steps, self.threshold)

    pending = extractor.query()
    # MISE keeps handing out points until it has nothing left to refine
    while pending.shape[0] != 0:
        # grid coordinates -> (0, 1) range on the GPU
        unit_points = torch.FloatTensor(pending).cuda() / extractor.resolution
        sdf = self.eval_part_points(unit_points, part_idx).astype(np.double)
        extractor.update(pending, sdf)
        pending = extractor.query()

    grid_points, grid_values = extractor.get_points()
    return grid_points, grid_values

In [121]:
# Settings shared by the part-decoding helpers below.
vox_dim = 64                                   # base grid size handed to MISE
threshold = 0.5                                # iso-level for isosurface reconstruction
upsampling_steps = 0                           # result resolution = vox_dim * 2^steps
resolution = vox_dim * 2 ** upsampling_steps   # final grid resolution
points_batch_size = 4 * 16 ** 3                # points sent through the part decoder per chunk
bypart = True


def infer_part_decoder(part_codes, points):
    """Evaluate the part decoder on query points, chunked to bound memory use.

    Points are processed ``points_batch_size`` at a time along the point axis;
    each chunk's result is moved to the CPU before the next chunk is run.

    :param part_codes: (n_parts, 1, en_z_dim)
    :param points: (n_parts, n_points, 3) value range (0, 1)
    :return: ndarray (n_parts, n_points, 1) decoder output for each point
    :raises RuntimeError: if ``points`` and ``part_codes`` disagree on n_parts
    """
    n_parts = part_codes.shape[0]
    if points.size(0) != n_parts:
        raise RuntimeError("pred:{} gt:{}".format(n_parts, points.size(0)))
    total_points = points.size(1)
    code_dim = part_codes.size(-1)

    chunks = []
    for start in range(0, total_points, points_batch_size):
        chunk = points[:, start:start + points_batch_size, :]
        chunk_len = chunk.size(1)
        # one copy of the part code per query point, flattened to match the points
        codes_flat = part_codes.repeat((1, chunk_len, 1)).view(-1, code_dim)
        points_flat = chunk.contiguous().view(-1, 3)

        decoded = part_decoder(points_flat, codes_flat).view((n_parts, chunk_len, -1))
        chunks.append(decoded.detach().cpu().numpy())
    return np.concatenate(chunks, axis=1)


def eval_part_points(points, part_vec):
    """Decode a single part at the given points and drop singleton axes.

    :param points: point tensor for one part; a leading part axis is added here
    :param part_vec: latent code for that part, forwarded to ``infer_part_decoder``
    :return: ndarray of decoder values with all size-1 dimensions squeezed out
    """
    single_part_points = points.unsqueeze(0)
    return infer_part_decoder(part_vec, single_part_points).squeeze()

def eval_part_sdf(part_vec):
    """Sample one part's implicit field on the MISE grid.

    Uses the module-level ``vox_dim``, ``upsampling_steps`` and ``threshold``.

    :param part_vec: [1, z_dim]
    :return: all_points: (n_points, 3)
             all_values: (n_points, )
    """
    extractor = MISE(vox_dim, upsampling_steps, threshold)

    pending = extractor.query()
    # keep evaluating until MISE has no further points to refine
    while pending.shape[0] != 0:
        # move the queried grid coordinates to the GPU and rescale to (0, 1)
        unit_points = torch.FloatTensor(pending).cuda() / extractor.resolution
        sdf = eval_part_points(unit_points, part_vec).astype(np.double)
        extractor.update(pending, sdf)
        pending = extractor.query()

    grid_points, grid_values = extractor.get_points()
    return grid_points, grid_values

def transform_points(points, values, transforms):
    """Transform part points from each part's local frame to the global frame.

    For every part the points are re-centred on the cube midpoint, scaled by
    the largest box extent, shifted by the box centre (converted to voxel
    units with the module-level ``resolution``), cropped to the part's
    bounding box and finally clipped to the grid.

    :param points: (n_parts, n_points, 3) or [(n_points1, 3), (n_points2, 3), ...], in range (0, vox_dim)
    :param values: (n_parts, n_points, 1) or [(n_points1, 1), (n_points2, 1), ...]
    :param transforms: [n_parts, 1, 6] tensor, same thing as boxparams; entries
                       0:3 are used as the box centre and 3:6 as the box size
    :return: (new_points, new_values), two lists with one array per part that
             contain only the samples lying inside that part's bounding box
    """
    cube_mid = np.asarray([resolution // 2, resolution // 2, resolution // 2]).reshape(1, 3)
    new_points, new_values = [], []
    for idx in range(len(points)):
        part_points = points[idx]
        part_values = values[idx]
        part_translation = (transforms[idx, 0, :3].reshape(1, 3) * resolution).detach().cpu().numpy()
        part_size = transforms[idx, 0, 3:6].reshape(3).detach().cpu().numpy()

        # uniform scale: the largest box extent
        part_scale = np.amax(part_size)
        part_points = (part_points - cube_mid) * part_scale + part_translation

        mins = part_translation - part_size * resolution / 2
        maxs = part_translation + part_size * resolution / 2
        # A sample is inside the box only if EVERY coordinate lies within
        # [mins, maxs]; testing the lower bound with a max over the axes would
        # let through points that are below the minimum on one or two axes.
        inside = np.all((part_points >= mins) & (part_points <= maxs), axis=1)
        part_points = part_points[inside, :]
        part_values = part_values[inside]

        part_points = np.clip(part_points, 0, resolution - 1)
        new_points.append(part_points)
        new_values.append(part_values)
    return new_points, new_values

def generate_shape(part_vecs, transforms, by_part=True):
    """Decode every part, place it in the global frame and build the final mesh.

    :param part_vecs: [n_parts, 1, z_dim]
    :param transforms: box parameters per part, forwarded to ``transform_points``
    :param by_part: bool. segment each part or put as a whole
    :return: dict with key 'mesh' holding the result of ``partsdf2mesh``
    """
    sampled = [eval_part_sdf(part_vecs[part_idx]) for part_idx in range(part_vecs.shape[0])]
    local_points = [part_points for part_points, _ in sampled]
    local_values = [part_values for _, part_values in sampled]

    global_points, global_values = transform_points(local_points, local_values, transforms)
    # a voxel output could be produced here with partsdf2voxel instead of a mesh
    shape_mesh = partsdf2mesh(global_points, global_values, affine=None, vox_dim=resolution, by_part=by_part)
    return {'mesh': shape_mesh}

def save_output(shape, filename, save_dir, form):
    """Write a generated shape into ``save_dir``.

    :param shape: for form 'voxel', the data stored in an HDF5 dataset named
                  'voxel'; for form 'mesh', an object exposing ``export(path)``
    :param filename: output file name without extension
    :param save_dir: destination directory; created if it does not exist yet
    :param form: 'voxel' (writes ``<filename>.h5``) or 'mesh' (writes ``<filename>.obj``)
    :raises NotImplementedError: for any other ``form``
    """
    # validate first so an unsupported form never touches the file system
    if form == 'voxel':
        extension = "h5"
    elif form == "mesh":
        extension = "obj"
    else:
        raise NotImplementedError("unsupported output form: {!r}".format(form))

    # an empty save_dir means "current directory"; os.makedirs('') would fail
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, "{}.{}".format(filename, extension))

    if form == 'voxel':
        with h5py.File(save_path, 'w') as fp:
            fp.create_dataset('voxel', data=shape, compression=9)
    else:
        shape.export(save_path)
        
def decode_n_save_shape_mesh(part_vecs, transforms, save_dir, filename):
    """Generate the mesh for a set of parts and save it as ``<filename>.obj``.

    :param part_vecs: part latent codes, forwarded to ``generate_shape``
    :param transforms: part box parameters, forwarded to ``generate_shape``
    :param save_dir: output directory
    :param filename: output file name without extension
    """
    generated = generate_shape(part_vecs, transforms)
    save_output(generated['mesh'], filename, save_dir, form="mesh")

In [54]:
def test_pqnet_generation():
    """Smoke test: sample one shape from the latent GAN and save its mesh.

    Draws a random noise vector, runs generator and critic, decodes the
    generated code into a part sequence and writes the resulting mesh to
    ``test/meshes/model_pqnet_gan.obj``.
    """
    out_dir = os.path.join('test','meshes')
    latent_noise = torch.randn(n_dim).cuda()
    with torch.no_grad():
        shape_code = netG(latent_noise)
        critic_score = netD(shape_code)  # evaluated but not used further
        decoded = infer_decoder(netDecoder, shape_code.view(2, 1, hidden_size*2))
        part_boxes = decoded['boxparams']  # [n_parts, 1, boxparam_dim]
        part_codes = decoded['vecs']  # [n_parts, 1, en_z_dim]
        decode_n_save_shape_mesh(part_codes, part_boxes, out_dir, 'model_pqnet_gan')


test_pqnet_generation()

In [66]:
# data utilities
import h5py

# Root directory of the per-chair part annotations; readJsonPartCategories
# reads <json_dir>/<chairID>/result.json.
json_dir = "data/parts_json"

def loadH5Full(path, resolution=64, rescale=True):
    """Read every per-part array of one shape from an HDF5 file.

    :param path: path of the .h5 file
    :param resolution: selects the ``points_<resolution>`` / ``values_<resolution>``
                       datasets (the part voxels are always read from
                       ``parts_voxel_scaled64``)
    :param rescale: if True, divide the loaded points by ``resolution``
    :return: (nParts, partVoxel, dataPoints, dataVals, scale, translation, size);
             note that scale comes before translation in the returned tuple
    """
    with h5py.File(path, 'r') as data_dict:
        nParts = data_dict.attrs['n_parts']
        # np.float64 instead of the np.float alias, which NumPy 1.24 removed
        partVoxel = data_dict['parts_voxel_scaled64'][:].astype(np.float64)
        dataPoints = data_dict['points_{}'.format(resolution)][:]
        dataVals = data_dict['values_{}'.format(resolution)][:]
        translation = data_dict['translations'][:]
        scale = data_dict['scales'][:]
        size = data_dict['size'][:]
    if rescale:
        dataPoints = dataPoints / resolution
    return nParts, partVoxel, dataPoints, dataVals, scale, translation, size

def readJsonPartCategories(chairID, nParts):
    """Look up the semantic category of each part of a chair.

    Reads ``<json_dir>/<chairID>/result.json`` and maps the ``text`` of every
    entry in ``data[0]['children']`` to a label:
    1 = "Chair Back", 2 = "Chair Seat", 3 = "Chair Arm", 4 = "Chair Base".
    Entries with any other name add no label.

    :param chairID: name of the chair's sub-directory inside ``json_dir``
    :param nParts: number of voxelized parts of the chair
    :return: list of at most ``nParts`` int labels; ``[0] * nParts`` when the
             annotation file is missing, unreadable or not shaped as expected
    """
    categories = []
    path = os.path.join(json_dir, chairID, "result.json")
    try:
        # context manager so the handle is closed even when parsing fails
        with open(path) as f:
            data = json.load(f)
        for (i, el) in enumerate(data[0]['children']):
            partName = el['text']
            if partName == "Chair Back":
                categories.append(1)
            elif partName == "Chair Seat":
                categories.append(2)
            elif partName == "Chair Arm":
                categories.append(3)
            elif partName == "Chair Base":
                # this will also apply to following parts
                categories.extend([4] * (nParts - i))
    except (OSError, ValueError, LookupError, TypeError):
        # Best-effort fallback for a missing file (OSError), malformed JSON
        # (ValueError) or an unexpected layout (LookupError / TypeError).
        # Programming errors such as NameError are no longer swallowed here.
        categories = [0] * nParts  # missing corresponding json labels for input h5
    # sometimes json file has more labels than voxelized parts, in which case truncated
    return categories[0:nParts]

def getChairPartInfos(encoder, data_dir, filename):
    """Encode all parts of one chair file and gather their metadata.

    :param encoder: network mapping a (nParts, 1, dim, dim, dim) voxel tensor on
                    the GPU to one latent vector per part
    :param data_dir: directory containing the file
    :param filename: name of the .h5 file; its last three characters are
                     stripped to obtain the chair id used for the category lookup
    :return: dict of
             'vecs': ndarray (nParts, latent dimension)
             'scales': as loaded from the file (per the original notes, (nParts, 1))
             'translations': as loaded from the file (per the original notes, (nParts, 3))
             'categories': ndarray of int labels from ``readJsonPartCategories``
             'filenames': list with ``filename`` repeated once per part
    """
    path = os.path.join(data_dir, filename)

    nParts, partVoxel, dataPoints, dataVals, scales, translations, size = loadH5Full(path, resolution=64)
    # np.float64 instead of the np.float alias, which NumPy 1.24 removed
    voxelTensor = torch.tensor(partVoxel.astype(np.float64), dtype=torch.float32).unsqueeze(1)  # (nParts, 1, dim, dim, dim)
    with torch.no_grad():
        latentVecs = encoder(voxelTensor.cuda()).cpu().numpy()
    # from here on the part count is taken from the encoder output, not the file attribute
    nParts = latentVecs.shape[0]
    strid = filename[0:-3]
    categories = np.array(readJsonPartCategories(strid, nParts))

    return {'vecs':latentVecs, 'scales':scales, 'translations':translations, 'categories':categories,
            'filenames':[filename]*len(latentVecs)}

def loadAllChairsInfoIterable(data_dir):
    """Lazily yield the part info dict of every ``.h5`` file in ``data_dir``.

    The directory is listed immediately; each file is only loaded and encoded
    (with the module-level ``part_encoder``) when the returned iterator is
    advanced. ``None`` results are skipped.

    :param data_dir: directory to scan
    :return: iterator over the dicts produced by ``getChairPartInfos``
             (part latent vectors, categories, and affine transforms)
    """
    h5_names = (name for name in os.listdir(data_dir) if name.endswith(".h5"))
    infos = (getChairPartInfos(part_encoder, data_dir, name) for name in h5_names)
    return (info for info in infos if info is not None)

In [141]:
from functools import reduce
def load_targets_info(target_dir):
    """Collect the part latent vectors of every target shape in ``target_dir``.

    :param target_dir: directory scanned by ``loadAllChairsInfoIterable``
    :return: dict of
             'vecs': all part latent vectors concatenated along axis 0
             'originfiles': ndarray with, per part, the file it came from
             'originindex': ndarray with, per part, its index inside that file
    :raises ValueError: if no target shape was found in ``target_dir``
    """
    targets = list(loadAllChairsInfoIterable(target_dir))
    if not targets:
        # fail with a clear message instead of an opaque error further down
        raise ValueError("no .h5 target shapes found in {!r}".format(target_dir))
    # flatten in one pass rather than repeatedly concatenating lists
    part_originfiles = np.array([name for info in targets for name in info['filenames']])
    part_originindex = np.concatenate([np.arange(0, len(info['filenames'])) for info in targets])
    part_vecs = np.concatenate([info['vecs'] for info in targets])
    return {'vecs':part_vecs, 'originfiles':part_originfiles, 'originindex':part_originindex}


# Latent-space search: starting from random noise, optimise the generator input
# so that every decoded part code moves towards its nearest target part code,
# then assemble the shape from those nearest target parts and save it as a mesh.
save_dir = os.path.join('test','meshes')
##########################
# adversarial parameters #
##########################
targets_path = "data/TestData/set2"  # change this to new data as needed
num_search = 4  # number of independent searches, each from a fresh noise vector
max_search_iter = 3000  # optimisation steps per search
sample_size = 64  # NOTE(review): not referenced anywhere in the visible cells
learning_rate = 1e-4
# uses weight decay to enforce gaussian prior for generator input
# https://stats.stackexchange.com/questions/163388/why-is-the-l2-regularization-equivalent-to-gaussian-prior
weight_decay_param = 0  # 0 disables the weight-decay term

#########################
# adversarial algorithm #
#########################
targets = load_targets_info(targets_path)
targetvecs = torch.cuda.FloatTensor(targets['vecs']).view(1, -1, en_z_dim) # [1, n_targetparts, en_z_dim]
for kth_search in range(num_search):
    # the noise vector is the only tensor handed to the optimiser
    adv_noise = torch.randn(n_dim, requires_grad=True, device="cuda")  # N(0,1)
    adv_adam = torch.optim.Adam([adv_noise], lr=learning_rate, betas=(0.5, 0.9), weight_decay=weight_decay_param)
    # per-iteration diagnostics, pre-sized so they can be filled by index
    distance_tracker = [None] * max_search_iter
    scores_tracker = [None] * max_search_iter
    means_tracker = [None] * max_search_iter
    vars_tracker = [None] * max_search_iter
    for ith_iter in range(max_search_iter):
        fake = netG(adv_noise)
        seq = infer_decoder(netDecoder, fake.view(2, 1, hidden_size*2))
        vecs = seq['vecs'] # [n_parts, 1, en_z_dim]
        # broadcasting [n_parts, 1, z] against [1, n_targetparts, z] gives all pairwise differences
        distances = torch.norm(vecs - targetvecs, dim=2) # [n_parts, n_targetparts]
        shortest_distances, closest_targets = torch.min(distances, dim=1) # [n_parts]
        # loss = sum over generated parts of the distance to the nearest target part
        loss = torch.sum(shortest_distances)
        loss.backward()
        adv_adam.step()
        adv_adam.zero_grad()
        # for debugging
        with torch.no_grad():
            score = netD(fake)
        distance_tracker[ith_iter] = loss.detach().to('cpu', non_blocking=True)
        scores_tracker[ith_iter] = score.detach().to('cpu', non_blocking=True)
        means_tracker[ith_iter] = torch.mean(adv_noise).detach().to('cpu', non_blocking=True)
        vars_tracker[ith_iter] = torch.var(adv_noise, unbiased=True).detach().to('cpu', non_blocking=True)
    
    # visualizations & output
    # NOTE: seq and closest_targets below come from the forward pass of the last
    # iteration, i.e. from the noise as it was before the final optimiser step.
    if num_search == 1: # plot debugging info when there is only one search
        distance_tracker = np.array([distance.numpy() for distance in distance_tracker]) # [n_iterations]
        scores_tracker = np.array([score.numpy() for score in scores_tracker]) # [n_iterations]
        fig, ax = plt.subplots(2,2, figsize=(10,10))
        ax[0, 0].set_title("Total Distances over Time")
        ax[0, 0].plot(distance_tracker, marker='o')
        ax[0, 0].set_xlabel("iteration")
        ax[0, 0].set_ylabel("total distance")
        ax[0, 1].set_title("Critic Score over Time")
        ax[0, 1].plot(scores_tracker, marker='o')
        ax[0, 1].set_xlabel("iteration")
        ax[0, 1].set_ylabel("critic score")
        ax[1, 0].set_title("Input Means over Time")
        ax[1, 0].plot(means_tracker, marker='o')
        ax[1, 0].set_xlabel("iteration")
        ax[1, 0].set_ylabel("mean")
        ax[1, 1].set_title("Input Variances over Time")
        ax[1, 1].plot(vars_tracker, marker='o')
        ax[1, 1].set_xlabel("iteration")
        ax[1, 1].set_ylabel("variance")
    boxs = seq['boxparams'] # # [n_parts, 1, boxparam_dim]
    # !!!!using closest parts in target range as part outputs!!!!
    # i.e. the generated part codes are replaced by their nearest target codes;
    # only the box parameters come from the generated sequence
    vecs = targetvecs[0, closest_targets, :].view(-1, 1, en_z_dim) # [n_parts, 1, en_z_dim]
    out_name = f'model_gana_{kth_search}'
    decode_n_save_shape_mesh(vecs, boxs, save_dir, out_name)
    print()
    print("Name:", out_name)
    print("Critic score:", scores_tracker[-1])
    print("Using parts from:", targets['originfiles'][closest_targets.cpu().numpy()])


Name: model_gana_0
Critic score: tensor([1.0693])
Using parts from: ['39055.h5' '39055.h5' '40096.h5' '40096.h5' '39055.h5' '37529.h5']

Name: model_gana_1
Critic score: tensor([-0.2922])
Using parts from: ['37529.h5' '39055.h5' '37529.h5']

Name: model_gana_2
Critic score: tensor([-0.2619])
Using parts from: ['39055.h5' '41975.h5' '37529.h5' '37529.h5' '37529.h5' '37529.h5']

Name: model_gana_3
Critic score: tensor([-0.3250])
Using parts from: ['41975.h5' '41975.h5' '41975.h5' '41975.h5']
