In [2]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import utils, DataLoader, mesh_sampling, sap
import numpy as np
import pyvista as pv
from skimage import measure
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from IPython.display import display
import meshplot as mp
import os, sys
from math import ceil
from scipy.ndimage import zoom
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Meshplot leaves a stray print statement in its code; this context manager suppresses it.
class HiddenPrints:
    """Context manager that silences ``print`` output inside a ``with`` block.

    On entry ``sys.stdout`` is redirected to ``os.devnull``; on exit the
    stream that was active before entry is put back. Exceptions raised by
    the body are not swallowed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink so __exit__ closes exactly the
        # file opened here, even if the body rebinds sys.stdout meanwhile.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore first so stdout is usable again even if close() fails.
        sys.stdout = self._original_stdout
        self._devnull.close()
        # Implicitly returns None, so any exception from the body propagates.

In [11]:
# Inference runs on the first CUDA device.
device = torch.device('cuda', 0)

# Directory holding the artifacts that were saved for the trained model.
model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus_two/models_two/2"


def _load_artifact(stem):
    """Return the object stored in ``<model_path>/<stem>.pt`` via ``torch.load``."""
    return torch.load(f"{model_path}/{stem}.pt")


# Trained weights plus the constructor arguments stored next to them.
model_state_dict = _load_artifact("model_state_dict")
in_channels = _load_artifact("in_channels")
out_channels = _load_artifact("out_channels")
latent_channels = _load_artifact("latent_channels")
spiral_indices_list = _load_artifact("spiral_indices_list")
up_transform_list = _load_artifact("up_transform_list")
down_transform_list = _load_artifact("down_transform_list")

# std / mean are applied to the decoder output in the interactive cell below.
std = _load_artifact("std")
mean = _load_artifact("mean")
template_face = _load_artifact("faces")

# Rebuild the autoencoder, restore its weights and move it to the device.
# Note the argument order: the down transforms come before the up transforms.
model = AE(
    in_channels,
    out_channels,
    latent_channels,
    spiral_indices_list,
    down_transform_list,
    up_transform_list,
)
model.load_state_dict(model_state_dict)
model.to(device)

# Switch to evaluation mode; as the last expression of the cell this also
# displays the module's layer summary.
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 16, seq_length=9)
    )
    (1-2): 2 x SpiralEnblock(
      (conv): SpiralConv(16, 16, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(16, 32, seq_length=9)
    )
    (4): Linear(in_features=3136, out_features=32, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=16, out_features=3136, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(32, 16, seq_length=9)
    )
    (3-4): 2 x SpiralDeblock(
      (conv): SpiralConv(16, 16, seq_length=9)
    )
    (5): SpiralConv(16, 3, seq_length=9)
  )
  (reg_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): BatchN

In [15]:
# Length of the latent vector passed to the decoder, and how many of its
# leading entries are exposed as sliders (the rest stay at zero).
LATENT_DIM = 16
N_SLIDERS = 6

TEMPLATE_MESH_PATH = '/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/template/template.ply'

# Latent code fed to the decoder. It is created on the model's device up front
# so the callback never has to rebind the global to move it.
z = torch.zeros(LATENT_DIM, device=device)

# Viewer handle; stays None until the first call to show() creates the plot.
plot = None

# The face connectivity does not depend on the sliders, so the template mesh
# is read from disk once here instead of on every slider change.
template_faces = np.asarray(o3d.io.read_triangle_mesh(TEMPLATE_MESH_PATH).triangles)


@mp.interact(**{f'z[{i}]': FloatSlider(min=-3.0, max=3.0, step=0.2, value=0) for i in range(N_SLIDERS)})
def show(**kwargs):
    """Decode the current slider values and draw (or refresh) the mesh.

    Args:
        **kwargs: slider values keyed ``'z[0]'`` .. ``'z[5]'``, supplied by
            the ``mp.interact`` decorator.

    Side effects:
        Writes the slider values into the first ``N_SLIDERS`` entries of the
        global ``z`` and creates or updates the global ``plot`` viewer.
    """
    global plot

    slider_values = [kwargs[f'z[{i}]'] for i in range(N_SLIDERS)]
    # In-place write: only the leading entries change, the tail keeps its zeros.
    z[:N_SLIDERS] = torch.tensor(slider_values, device=z.device)

    with torch.no_grad():
        pred = model.decoder(z)
        # Undo the normalisation on the CPU; one row of (x, y, z) per vertex.
        verts = ((pred.view(-1, 3).cpu() * std) + mean).numpy()

    if plot is None:
        plot = mp.plot(verts, template_faces, return_plot=True)
    else:
        # update_object prints to stdout; keep the widget output clean.
        with HiddenPrints():
            plot.update_object(vertices=verts, faces=template_faces)
        display(plot._renderer)

interactive(children=(FloatSlider(value=0.0, description='z[0]', max=3.0, min=-3.0, step=0.2), FloatSlider(val…