In [None]:
import json
import os

import ipywidgets as widgets
import numpy as np
import torch
from IPython.display import display

import deep_sdf.workspace as ws
from deep_sdf import data, utils, mesh

# Experiment directory holding "specs.json" and the saved decoder checkpoints.
saved_model_path = "examples/torus"  # Path to the saved decoder model
# Checkpoint file name (without ".pth") inside the model-params subdirectory.
# NOTE(review): assumes a "latest.pth" checkpoint exists -- adjust if the
# experiment stores its checkpoints under a different name.
checkpoint = "latest"

# Load the decoder
specs_filename = os.path.join(saved_model_path, "specs.json")

if not os.path.isfile(specs_filename):
    raise FileNotFoundError(
        'The experiment directory does not include specifications file "specs.json"'
    )

# Use a context manager so the specs file handle is closed right after parsing.
with open(specs_filename) as specs_file:
    specs = json.load(specs_file)

# The architecture module is selected by name from the experiment specs.
arch = __import__("networks." + specs["NetworkArch"], fromlist=["Decoder"])

latent_size = specs["CodeLength"]

decoder = arch.Decoder(latent_size, **specs["NetworkSpecs"])

# Wrap before loading weights; presumably the checkpoint was saved from a
# DataParallel model, so its state-dict keys expect this wrapper -- TODO confirm.
decoder = torch.nn.DataParallel(decoder)

saved_model_state = torch.load(
    os.path.join(saved_model_path, ws.model_params_subdir, checkpoint + ".pth")
)
saved_model_epoch = saved_model_state["epoch"]

decoder.load_state_dict(saved_model_state["model_state_dict"])

# Unwrap the underlying module, move it to the GPU and switch to inference mode.
decoder = decoder.module.cuda()
decoder.eval()

# Initial latent vector: the all-zero code, sized from the experiment specs so
# it always matches the decoder's expected code length.
latent = torch.zeros(1, latent_size).cuda()

# Function to update the latent vector and visualize the shape
def update_shape(**kwargs):
    """Copy slider values into the latent code and regenerate the mesh.

    Args:
        **kwargs: One float per latent dimension, keyed ``"var_<i>"`` where
            ``i`` runs from 0 to ``latent.shape[1] - 1``.

    Raises:
        KeyError: If a ``"var_<i>"`` entry is missing for any latent dimension.

    Side effects:
        Overwrites the module-level ``latent`` tensor in place and writes the
        reconstructed mesh to ``live_shape.ply`` through ``mesh.create_mesh``.
    """
    # Derive the number of dimensions from the tensor itself rather than
    # hard-coding it, so the function follows whatever size `latent` has.
    num_dims = latent.shape[1]
    values = [kwargs[f"var_{i}"] for i in range(num_dims)]

    # Single bulk copy instead of one device write per element.
    latent[0] = torch.tensor(values, dtype=latent.dtype, device=latent.device)

    # Generate the 3D shape using the decoder
    mesh_filename = "live_shape.ply"
    with torch.no_grad():
        mesh.create_mesh(decoder, latent, mesh_filename, N=256, max_batch=int(2 ** 18))

    # Visualize the shape (can use external tools or libraries like trimesh)
    print(f"Generated shape saved to: {mesh_filename}")

# Create one slider per latent dimension; the count is taken from the latent
# tensor so the controls always match the code length in use.
num_latent_dims = latent.shape[1]
sliders = {
    f"var_{i}": widgets.FloatSlider(
        value=0.0,
        min=-3.0,
        max=3.0,
        step=0.1,
        description=f"Var {i}",
        # Fire the callback only when the slider is released: each update
        # rebuilds a mesh at N=256, which should not run on every drag tick.
        continuous_update=False,
    )
    for i in range(num_latent_dims)
}

# Bind the sliders to update_shape and render the interactive control panel.
interactive_plot = widgets.interactive(update_shape, **sliders)
display(interactive_plot)