In [None]:
#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.

In [3]:
import argparse
import json
import logging
import os
import random
import time
import torch

In [2]:
import deep_sdf
import deep_sdf.workspace as ws

In [None]:
def reconstruct(
    decoder,
    num_iterations,
    latent_size,
    test_sdf,
    stat,
    clamp_dist,
    num_samples=30000,
    lr=5e-4,
    l2reg=False,
):
    """Fit a latent code so that ``decoder`` reproduces the samples in ``test_sdf``.

    Only the latent code is handed to the optimizer; the decoder is put in
    eval mode and its parameters are never stepped. Every iteration draws a
    fresh batch via ``deep_sdf.data.unpack_sdf_samples_from_ram``, clamps both
    the target and the predicted SDF to ``[-clamp_dist, clamp_dist]`` and
    minimizes their L1 distance with Adam.

    Args:
        decoder: Module called on the concatenation ``[latent, xyz]`` (along
            dimension 1) that returns the predicted SDF values.
        num_iterations: Number of optimization steps.
        latent_size: Length of the latent code (used when ``stat`` is a float).
        test_sdf: Sample container forwarded unchanged to
            ``unpack_sdf_samples_from_ram``.
        stat: Either a float, used as the standard deviation of a zero-mean
            normal initialization of shape ``(1, latent_size)``, or an
            indexable pair of tensors whose items 0 and 1 are passed to
            ``torch.normal`` as mean and standard deviation respectively.
        clamp_dist: Clamping bound applied to target and prediction.
        num_samples: Batch size requested per iteration; the latent code is
            expanded to this many rows, so the sampler is assumed to return
            exactly ``num_samples`` rows -- TODO confirm.
        lr: Initial Adam learning rate; divided by 10 once per
            ``num_iterations // 2`` iterations.
        l2reg: When true, adds ``1e-4 * mean(latent ** 2)`` to the loss.

    Returns:
        ``(loss_num, latent)``: the loss of the last iteration as a NumPy
        value (the integer 0 if no iteration ran) and the optimized latent
        tensor, which lives on the GPU and has ``requires_grad`` set.
    """

    def adjust_learning_rate(
        initial_lr, optimizer, num_iterations, decreased_by, adjust_lr_every
    ):
        # Step decay: one division by `decreased_by` per `adjust_lr_every`
        # completed iterations.
        lr = initial_lr * ((1 / decreased_by) ** (num_iterations // adjust_lr_every))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

    decreased_by = 10
    # Half of the run would be 0 for num_iterations == 1 and make the floor
    # division in adjust_learning_rate raise ZeroDivisionError, hence max().
    adjust_lr_every = max(1, num_iterations // 2)

    if isinstance(stat, float):
        latent = torch.ones(1, latent_size).normal_(mean=0, std=stat).cuda()
    else:
        latent = torch.normal(stat[0].detach(), stat[1].detach()).cuda()
    latent.requires_grad = True

    optimizer = torch.optim.Adam([latent], lr=lr)
    loss_l1 = torch.nn.L1Loss()
    loss = None

    # eval() is idempotent, so it is set once instead of on every iteration.
    decoder.eval()

    for e in range(num_iterations):
        sdf_data = deep_sdf.data.unpack_sdf_samples_from_ram(
            test_sdf, num_samples
        ).cuda()
        # Columns 0..2 are the query coordinates, column 3 the target SDF.
        xyz = sdf_data[:, 0:3]
        sdf_gt = sdf_data[:, 3].unsqueeze(1)
        sdf_gt = torch.clamp(sdf_gt, -clamp_dist, clamp_dist)

        adjust_learning_rate(lr, optimizer, e, decreased_by, adjust_lr_every)
        optimizer.zero_grad()

        # Broadcast the single latent code to one row per sample.
        latent_inputs = latent.expand(num_samples, -1)
        inputs = torch.cat([latent_inputs, xyz], 1).cuda()
        pred_sdf = decoder(inputs)

        # TODO: why is this needed?
        # NOTE(review): the repeated forward pass on the first iteration is
        # kept as-is because its purpose is not documented.
        if e == 0:
            pred_sdf = decoder(inputs)

        pred_sdf = torch.clamp(pred_sdf, -clamp_dist, clamp_dist)
        loss = loss_l1(pred_sdf, sdf_gt)
        if l2reg:
            loss += 1e-4 * torch.mean(latent.pow(2))
        loss.backward()
        optimizer.step()

        if e % 50 == 0:
            logging.debug(loss.cpu().data.numpy())
            logging.debug(e)
            logging.debug(latent.norm())

    # Convert once at the end rather than forcing a device-to-host copy on
    # every iteration; 0 is returned when the loop never ran.
    loss_num = 0 if loss is None else loss.cpu().data.numpy()
    return loss_num, latent

In [None]:
def get_filenames(data_source, split):
    """Collect the instances of ``split`` that have a non-empty sample folder.

    ``split`` is a nested mapping ``dataset -> class name -> instance names``.
    For each instance the folder
    ``<data_source>/<ws.sdf_samples_subdir>/<dataset>/<class>/<instance>`` is
    inspected; a message is printed for folders that are missing or empty.

    Returns:
        List of relative ``dataset/class/instance`` paths (not joined with
        ``data_source``) whose folder exists and contains at least one entry,
        in the iteration order of ``split``.
    """
    print(data_source)
    usable = []
    for dataset, classes in split.items():
        for class_name, instances in classes.items():
            for instance_name in instances:
                relative = os.path.join(dataset, class_name, instance_name)
                candidate = os.path.join(
                    data_source, ws.sdf_samples_subdir, relative
                )
                # Guard clauses: report and skip anything unusable.
                if not os.path.isdir(candidate):
                    print(candidate + ' does not exist')
                    continue
                if not os.listdir(candidate):
                    print(candidate + ' is empty')
                    continue
                usable.append(relative)
    return usable

In [8]:
# Load the experiment specification for the chairs example and read the JSON
# split file named by its "TrainSplit" entry.
specs = ws.load_experiment_specifications('examples/chairs')
train_split_file = specs["TrainSplit"]

with open(train_split_file, "r") as f:
    train_split = json.load(f)

# List the training instances for the data source named in the specs.
# NOTE(review): `foldernames` is a notebook-global that the cells below read
# and later overwrite.
foldernames = deep_sdf.data.get_instance_filenames(specs["DataSource"], train_split)




In [11]:
# Print every entry with its last three characters removed (`ii` is unused).
# NOTE(review): the recorded output ends in '.', so the entries presumably end
# in a three-letter extension; use [:-4] or os.path.splitext if the dot should
# be dropped as well.
for ii, folder in enumerate(foldernames):
    print(folder[:-3])

ShapeNetV2/03001627/1007e20d5e811b308351982a6e40cf41.
ShapeNetV2/03001627/1013f70851210a618f2e765c4a8ed3d.
ShapeNetV2/03001627/1015e71a0d21b127de03ab2a27ba7531.
ShapeNetV2/03001627/1016f4debe988507589aae130c1f06fb.
ShapeNetV2/03001627/1022fe7dd03f6a4d4d5ad9f13ac9f4e7.
ShapeNetV2/03001627/1028b32dc1873c2afe26a3ac360dbd4.
ShapeNetV2/03001627/1031fc859dc3177a2f84cb7932f866fd.
ShapeNetV2/03001627/1033ee86cc8bac4390962e4fb7072b86.
ShapeNetV2/03001627/103a0a413d4c3353a723872ad91e4ed1.
ShapeNetV2/03001627/103a60f3b09107df2da1314e036b435e.
ShapeNetV2/03001627/103b75dfd146976563ed57e35c972b4b.
ShapeNetV2/03001627/104256e5bb73b0b719fb4103277a6b93.
ShapeNetV2/03001627/1055f78d441d170c4f3443b22038d340.
ShapeNetV2/03001627/10709332176024ce9e47e7a22e24daa3.
ShapeNetV2/03001627/1079635b3da12a812cee4bf5d0f11ffe.
ShapeNetV2/03001627/107caefdad02cf1c8ab8e68cb52baa6a.
ShapeNetV2/03001627/108238b535eb293cd79b19c7c4f0e293.
ShapeNetV2/03001627/1093d35c2ac73bb74ca84d60642ec7e8.
ShapeNetV2/03001627/10991b3b01

In [None]:
# Configuration for the reconstruction run below.
# Number of optimization steps handed to reconstruct() for each instance.
iterations = 2000
experiment_directory = 'examples/chairs'
specs = ws.load_experiment_specifications(experiment_directory)
# Checkpoint name; ".pth" is appended when the saved model state is loaded.
checkpoint = 'latest'
split_filename = specs["TestSplit"]
data_source = specs["DataSource"]
# When True, instances whose mesh (.ply) and latent file both already exist
# are not reconstructed again.
skip = False

def empirical_stat(latent_vecs, indices):
    """Return the per-dimension mean and variance of selected latent codes.

    The tensors ``latent_vecs[ind]`` for every ``ind`` in ``indices`` are
    concatenated along dimension 0 and reduced over that dimension.

    Args:
        latent_vecs: Indexable collection of tensors that can be concatenated
            along dimension 0 (same device and trailing shape).
        indices: Iterable of indices/keys into ``latent_vecs``.

    Returns:
        ``(mean, var)`` computed with ``torch.mean`` and ``torch.var`` over
        dimension 0. ``torch.var`` uses its default (unbiased) estimator, so a
        single selected row yields NaN variance.

    Raises:
        ValueError: If ``indices`` selects nothing; the statistics would be
            undefined.
    """
    selected = [latent_vecs[ind] for ind in indices]
    if not selected:
        raise ValueError("empirical_stat requires at least one index")
    # One concatenation on whatever device the codes already live on, instead
    # of growing a GPU-only accumulator one entry at a time.
    lat_mat = torch.cat(selected, 0)
    mean = torch.mean(lat_mat, 0)
    var = torch.var(lat_mat, 0)
    return mean, var

# if not os.path.isfile(specs_filename):
#     raise Exception(
#         'The experiment directory does not include specifications file "specs.json"'
#     )

# specs = json.load(open(specs_filename))

# Build the decoder described by the experiment specs and restore its weights.
# The architecture module is resolved by name from the `networks` package and
# must expose a `Decoder` class.
arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

latent_size = specs["CodeLength"]

decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"])

# Wrapped before load_state_dict -- presumably because the checkpoint was
# saved from a DataParallel-wrapped model; TODO confirm.
decoder = torch.nn.DataParallel(decoder)

# NOTE(review): torch.load unpickles the file, so only trusted checkpoints
# should be loaded here.
saved_model_state = torch.load(
    os.path.join(
        experiment_directory, ws.model_params_subdir, checkpoint + ".pth"
        )
    )
# The epoch stored in the checkpoint names the reconstruction output folder.
saved_model_epoch = saved_model_state["epoch"]

decoder.load_state_dict(saved_model_state["model_state_dict"])

# Unwrap the DataParallel container and move the bare module to the GPU.
decoder = decoder.module.cuda()

# Fit a latent code to every instance of the test split, then save the code
# and (unless save_latvec_only is set) a mesh extracted from it.
with open(split_filename, "r") as f:
    split = json.load(f)

# Only json.load needs the split file, so everything below runs after it has
# been closed rather than holding the handle open for the whole run.
foldernames = deep_sdf.data.get_filenames(data_source, split)

# Instances are processed in random order.
random.shuffle(foldernames)

logging.debug(decoder)

err_sum = 0.0
# Completed reconstruct() calls; denominator of the running error average.
num_reconstructed = 0
repeat = 1
save_latvec_only = False
rerun = 0

# Outputs are grouped by the epoch stored in the loaded checkpoint.
reconstruction_dir = os.path.join(
    experiment_directory, ws.reconstructions_subdir, str(saved_model_epoch)
)
os.makedirs(reconstruction_dir, exist_ok=True)

reconstruction_meshes_dir = os.path.join(
    reconstruction_dir, ws.reconstruction_meshes_subdir
)
os.makedirs(reconstruction_meshes_dir, exist_ok=True)

reconstruction_codes_dir = os.path.join(
    reconstruction_dir, ws.reconstruction_codes_subdir
)
os.makedirs(reconstruction_codes_dir, exist_ok=True)

for ii, folder in enumerate(foldernames):

    full_filename = os.path.join(data_source, ws.sdf_samples_subdir, folder)

    # Validate the joined sample path: `folder` on its own would be resolved
    # against the current working directory, not the data source.
    if not os.path.isdir(full_filename):
        continue

    if not os.listdir(full_filename):
        continue

    data_sdf = deep_sdf.data.read_sdfs_into_ram(full_filename)

    for k in range(repeat):

        # Suffix the outputs with a run index only when rerun > 1.
        if rerun > 1:
            mesh_filename = os.path.join(
                reconstruction_meshes_dir, folder + "-" + str(k + rerun)
            )
            latent_filename = os.path.join(
                reconstruction_codes_dir, folder + "-" + str(k + rerun) + ".pth"
            )
        else:
            mesh_filename = os.path.join(reconstruction_meshes_dir, folder)
            latent_filename = os.path.join(
                reconstruction_codes_dir, folder + ".pth"
            )

        if (
            skip
            and os.path.isfile(mesh_filename + ".ply")
            and os.path.isfile(latent_filename)
        ):
            continue

        # Shuffle the rows of both sample tensors before fitting.
        data_sdf[0] = data_sdf[0][torch.randperm(data_sdf[0].shape[0])]
        data_sdf[1] = data_sdf[1][torch.randperm(data_sdf[1].shape[0])]

        start = time.time()
        err, latent = reconstruct(
            decoder,
            int(iterations),
            latent_size,
            data_sdf,
            0.01,  # stat: std of the initial latent code (alt: [emp_mean,emp_var])
            0.1,  # clamp_dist
            num_samples=8000,
            lr=5e-3,
            l2reg=True,
        )
        logging.debug("reconstruct time: {}".format(time.time() - start))
        err_sum += err
        num_reconstructed += 1
        # Average over the reconstructions actually performed; the loop index
        # would also count skipped instances and ignore `repeat`.
        logging.debug("current_error avg: {}".format((err_sum / num_reconstructed)))
        logging.debug(ii)

        logging.debug("latent: {}".format(latent.detach().cpu().numpy()))

        decoder.eval()

        # `folder` contains sub-directories, so the parent of each output
        # file has to be created as well.
        os.makedirs(os.path.dirname(mesh_filename), exist_ok=True)

        if not save_latvec_only:
            start = time.time()
            with torch.no_grad():
                deep_sdf.mesh.create_mesh(
                    decoder, latent, mesh_filename, N=256, max_batch=int(2 ** 18)
                )
            logging.debug("total time: {}".format(time.time() - start))

        os.makedirs(os.path.dirname(latent_filename), exist_ok=True)

        torch.save(latent.unsqueeze(0), latent_filename)