In [None]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import utils, DataLoader, mesh_sampling, sap
import numpy as np
import pyvista as pv
from skimage import measure
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from IPython.display import display
import meshplot as mp
import os, sys
from math import ceil
from scipy.ndimage import zoom
import open3d as o3d

In [None]:
# Meshplot left an annoying print statement in their code. Use this context manager to suppress it.
class HiddenPrints:
    """Context manager that silences ``print`` output inside its ``with`` block.

    On entry ``sys.stdout`` is replaced by a handle to ``os.devnull``; on exit
    the stream that was active at entry is restored and the devnull handle is
    closed. Exceptions raised inside the block propagate unchanged.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        # Keep our own reference to the handle so __exit__ closes exactly the
        # file opened here, even if sys.stdout is reassigned inside the block.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            self._devnull.close()
        finally:
            # Restore stdout even if closing the devnull handle fails.
            sys.stdout = self._original_stdout

In [6]:
# GPU on which the restored model runs.
device = torch.device('cuda', 0)

# Directory holding the artifacts that were saved alongside the trained model.
model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models/0"

# One file per saved artifact. The order of this tuple must match the order of
# the names in the unpacking assignment below.
checkpoint_files = (
    f"{model_path}/model_state_dict.pt",
    f"{model_path}/in_channels.pt",
    f"{model_path}/out_channels.pt",
    f"{model_path}/latent_channels.pt",
    f"{model_path}/spiral_indices_list.pt",
    f"{model_path}/up_transform_list.pt",
    f"{model_path}/down_transform_list.pt",
    f"{model_path}/std.pt",
    f"{model_path}/mean.pt",
    f"{model_path}/faces.pt",
)

# Load every artifact, in the order listed above.
(
    model_state_dict,
    in_channels,
    out_channels,
    latent_channels,
    spiral_indices_list,
    up_transform_list,
    down_transform_list,
    std,
    mean,
    template_face,
) = map(torch.load, checkpoint_files)

# Rebuild the autoencoder from the saved hyper-parameters and transforms,
# then restore the trained weights and move the model to the target device.
model = AE(
    in_channels, out_channels, latent_channels,
    spiral_indices_list, down_transform_list, up_transform_list,
)
model.load_state_dict(model_state_dict)
model.to(device)

# Switch to evaluation mode; as the last expression of the cell this also
# displays the model summary.
model.eval()

In [9]:
# Number of latent sliders, taken from the loaded model configuration instead
# of a hard-coded 16.
# NOTE(review): assumes latent_channels is an int-like scalar equal to the
# decoder's input size (16 for this checkpoint) — confirm.
n_latent = int(latent_channels)

# Template mesh that supplies the triangle connectivity for every reconstruction.
TEMPLATE_MESH_PATH = '/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/template/template.ply'

# Read the faces once here rather than on every slider move: they do not depend
# on z. np.array copies the triangles into a standalone NumPy array.
template_mesh = o3d.io.read_triangle_mesh(TEMPLATE_MESH_PATH)
template_triangles = np.array(template_mesh.triangles)

z = torch.zeros(n_latent)
plot = None  # meshplot viewer; created on the first call to show()


@mp.interact(**{f'z[{i}]': FloatSlider(min=-1.0, max=1.0, step=0.2, value=0) for i in range(n_latent)})
def show(**kwargs):
    """Decode the latent code chosen with the sliders and render the mesh.

    Parameters
    ----------
    **kwargs
        One float per latent dimension, keyed ``'z[0]'`` ... ``'z[n_latent-1]'``
        (supplied by the sliders declared in the decorator).

    Side effects
    ------------
    Overwrites the global ``z`` with the current latent code (moved to
    ``device``) and creates or updates the global meshplot viewer ``plot``.
    """
    global plot
    global z
    z = torch.tensor([kwargs[f'z[{i}]'] for i in range(n_latent)])
    with torch.no_grad():
        z = z.to(device)
        pred = model.decoder(z)

        # Reshape to one xyz row per vertex and rescale with the saved
        # std / mean (presumably undoing the training-time normalisation).
        verts = ((pred.view(-1, 3).cpu() * std) + mean).cpu().numpy()

    if plot is None:
        # First call: build the viewer (mp.plot also displays it).
        plot = mp.plot(verts, template_triangles, return_plot=True)
    else:
        # Later calls: update the geometry in place and silence meshplot's
        # stray print while doing so.
        with HiddenPrints():
            plot.update_object(vertices=verts, faces=template_triangles)
        display(plot._renderer)

FloatSlider(value=0.1, max=0.3, step=0.01)

FloatSlider(value=0.05, max=0.1, min=0.01, step=0.01)

IntSlider(value=10, max=25, min=5)

FloatSlider(value=0.2, max=0.4, step=0.05)

IntSlider(value=1, min=1)

Button(description='Show', style=ButtonStyle())