In [3]:
import importlib
import json
import os
import sys

import meshplot as mp
import numpy as np
import torch
import trimesh
from IPython.display import display
from ipywidgets import FloatSlider

import deep_sdf.workspace as ws
from deep_sdf import data, utils, mesh

# Path to the trained experiment; it must contain specs.json and a checkpoint
# under ws.model_params_subdir.
experiment_directory = "examples/torus/Run_w_con_loss_one_variable_usual"

# Load the experiment specification.
specs_filename = os.path.join(experiment_directory, "specs.json")

if not os.path.isfile(specs_filename):
    # FileNotFoundError subclasses Exception, so existing handlers still catch it.
    raise FileNotFoundError(
        'The experiment directory does not include specifications file "specs.json"'
    )

# Use a context manager so the specs file handle is closed after reading.
with open(specs_filename) as specs_file:
    specs = json.load(specs_file)

# The decoder architecture module is named by specs["NetworkArch"] inside the
# `networks` package.
arch = importlib.import_module("networks." + specs["NetworkArch"])

latent_size = specs["CodeLength"]

decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"])

# Wrap in DataParallel before loading so the state-dict keys line up with the
# checkpoint (presumably saved from a DataParallel model, i.e. "module."-prefixed
# keys -- confirm against the training script). The wrapper is removed below.
decoder = torch.nn.DataParallel(decoder)

saved_model_state = torch.load(
    os.path.join(experiment_directory, ws.model_params_subdir, "latest.pth")
)
saved_model_epoch = saved_model_state["epoch"]

decoder.load_state_dict(saved_model_state["model_state_dict"])

# Unwrap to the plain module, move it to the GPU and switch to inference mode.
decoder = decoder.module.cuda()
decoder.eval()

Decoder(
  (lin0): Linear(in_features=19, out_features=512, bias=True)
  (lin1): Linear(in_features=512, out_features=512, bias=True)
  (lin2): Linear(in_features=512, out_features=512, bias=True)
  (lin3): Linear(in_features=512, out_features=493, bias=True)
  (lin4): Linear(in_features=512, out_features=512, bias=True)
  (lin5): Linear(in_features=512, out_features=512, bias=True)
  (lin6): Linear(in_features=512, out_features=512, bias=True)
  (lin7): Linear(in_features=512, out_features=512, bias=True)
  (lin8): Linear(in_features=512, out_features=1, bias=True)
  (relu): ReLU()
  (th): Tanh()
)

In [2]:
# Meshplot left an annoying print statement in its code. This context manager
# suppresses it by temporarily redirecting stdout to os.devnull.
class HiddenPrints:
    """Context manager that silences writes to ``sys.stdout`` inside its block.

    On entry, ``sys.stdout`` is replaced by a freshly opened handle to
    ``os.devnull``. On exit the original stream is restored and that devnull
    handle is closed. Exceptions raised inside the block propagate unchanged.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference so __exit__ closes the handle we opened,
        # not whatever sys.stdout happens to be at exit time.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        self._devnull.close()
        # Implicitly returns None (falsy), so exceptions are not suppressed.

In [14]:
# Interactive latent-space explorer: one slider per latent dimension. The vector
# size comes from latent_size (specs["CodeLength"]) rather than a hard-coded 16,
# so the cell works for experiments with other code lengths.
z = torch.zeros(latent_size)
plot = None


@mp.interact(**{f'z[{i}]': FloatSlider(min=-1, max=1, step=0.2, value=0) for i in range(latent_size)})
def show(**kwargs):
    """Slider callback: decode the current latent vector into a mesh and display it.

    Builds the latent vector from the ``'z[i]'`` slider values and stores it in
    the global ``z`` on the decoder's device. It then runs deep_sdf's
    ``create_mesh`` (which writes ``<mesh_filename>.ply`` to disk), reloads that
    file with trimesh, and shows it with meshplot. The first call creates the
    plot; later calls update it in place.
    """
    global plot
    global z

    # Put the latent vector on whichever device the decoder's parameters live on.
    device = next(decoder.parameters()).device
    z = torch.tensor([kwargs[f'z[{i}]'] for i in range(latent_size)]).to(device)

    # create_mesh writes mesh_filename + ".ply"; derive the path from one source.
    mesh_filename = "utils/interpolation"
    ply_path = mesh_filename + ".ply"

    # Generate the 3D shape using the decoder (inference only, no autograd).
    with torch.no_grad():
        mesh.create_mesh(decoder, z, mesh_filename, N=256, max_batch=int(2 ** 18))

    loaded_mesh = trimesh.load(ply_path)
    verts = np.array(loaded_mesh.vertices)
    faces = np.array(loaded_mesh.faces)

    if plot is None:
        white_color = np.array([1.0, 1.0, 1.0])  # RGB white for the whole mesh
        plot = mp.plot(verts, faces, c=white_color, return_plot=True)
    else:
        # update_object prints to stdout; silence it.
        with HiddenPrints():
            plot.update_object(vertices=verts, faces=faces)
        display(plot._renderer)

interactive(children=(FloatSlider(value=0.0, description='z[0]', max=1.0, min=-1.0, step=0.2), FloatSlider(val…