In [3]:
import os, sys

from math import ceil
import pyvista as pv
import numpy as np
import meshplot as mp
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from skimage import measure
from scipy.ndimage import zoom
from scipy.interpolate import interpn
from IPython.display import display
from einops import rearrange
import igl
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import torch
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
import open3d as o3d
from IPython.display import display
import random

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [4]:
# Meshplot leaves an annoying print statement in its code; this context manager suppresses it.
class HiddenPrints:
    """Context manager that silences everything written to ``sys.stdout``.

    On entry ``sys.stdout`` is swapped for a freshly opened handle on
    ``os.devnull``; on exit that handle is closed and the previous stream is
    put back. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __enter__(self):
        saved_stream = sys.stdout
        self._original_stdout = saved_stream
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        devnull_handle = sys.stdout
        devnull_handle.close()
        sys.stdout = self._original_stdout

In [7]:
# Dot product along the first dimension of n-dimensional arrays x and y
def dot(x, y):
    """Inner product of ``x`` and ``y`` taken along their first axis.

    Axis 0 holds the vector components; every remaining axis is carried
    through (broadcast the way ``np.einsum`` broadcasts ellipsis axes), so
    inputs of shape ``(k, *rest)`` give a result of shape ``rest``.
    """
    # 'i' is the contracted leading axis; '...' passes all trailing axes through.
    signature = 'i...,i...->...'
    return np.einsum(signature, x, y)

# Signed distance functions from Inigo Quilez https://iquilezles.org/articles/distfunctions/
# A smooth-minimum operation could also be implemented to compose shapes into more complex scenes.
def sdf_sphere(x, radius):
    """Signed distance to a sphere of ``radius`` centred at the origin.

    ``x`` stacks the point coordinates on axis 0. The result is negative
    inside the sphere, zero on its surface and positive outside.
    """
    distance_from_origin = np.linalg.norm(x, axis=0)
    return distance_from_origin - radius

def sdf_capsule(x, a, b, r):
    """Signed distance to a capsule: the segment ``a``-``b`` swept by radius ``r``.

    Parameters
    ----------
    x : array, shape (d, *grid)
        Query points with coordinate components on axis 0 (same layout as the
        other ``sdf_*`` helpers).
    a, b : array-like, shape (d,) or (d, 1, ..., 1)
        Segment end points.
    r : float
        Capsule radius.

    Returns
    -------
    array, shape ``grid``
        Negative inside, zero on the surface, positive outside. If ``a == b``,
        the capsule degenerates to a sphere of radius ``r`` around ``a``.
    """
    x = np.asarray(x, dtype=float)
    # Reshape the end points to (d, 1, ..., 1) so they broadcast against every query point.
    point_shape = (-1,) + (1,) * (x.ndim - 1)
    a = np.asarray(a, dtype=float).reshape(point_shape)
    b = np.asarray(b, dtype=float).reshape(point_shape)
    xa = x - a
    ba = b - a
    # Contract over the coordinate axis (same operation as ``dot``).
    numerator = np.einsum('i...,i...->...', xa, ba)
    denominator = np.einsum('i...,i...->...', ba, ba)
    # A zero-length segment gives 0/0; since the numerator is then 0 as well,
    # dividing by 1 instead yields h = 0 (distance to the point a).
    safe_denominator = np.where(denominator > 0, denominator, 1.0)
    h = np.clip(numerator / safe_denominator, 0., 1.)
    # Per-point distance to the closest point on the segment, minus the radius.
    return np.linalg.norm(xa - ba * h, axis=0) - r

def sdf_torus(x, radius, thickness):
    """Signed distance to a torus centred at the origin.

    The torus lies in the plane of components ``x[0]`` and ``x[1]``; component
    ``x[2]`` runs along its symmetry axis. ``radius`` is the distance from the
    origin to the centre of the tube and ``thickness`` is the tube radius.
    ``x`` stacks the coordinates on axis 0.
    """
    in_plane = np.linalg.norm(x[[0, 1]], axis=0)
    ring_offset = in_plane - radius
    tube_vector = np.stack([ring_offset, x[2]])
    return np.linalg.norm(tube_vector, axis=0) - thickness

# Crop an n-dimensional image to a region of the given shape centred in the image
def center_crop(img, shape):
    """Return the centred sub-array of ``img`` with extent ``shape``.

    One slice is built per entry of ``shape``, paired with the leading axes of
    ``img``. Each slice starts at ``size // 2 - extent // 2`` and spans
    ``extent`` elements. Axes of ``img`` beyond ``len(shape)`` are left untouched.
    """
    window = []
    for size, extent in zip(img.shape, shape):
        first = size // 2 - extent // 2
        window.append(slice(first, first + extent))
    return img[tuple(window)]

# Add smooth random noise to coordinates
def gradient_noise(x, scale, strength, seed=None):
    """Smooth random displacement field sampled on the grid of ``x``.

    A coarse white-noise field with ``ceil(size / scale)`` samples per axis is
    upsampled by ``scale`` with ``scipy.ndimage.zoom``, then centre-cropped back
    to the grid of ``x`` and differentiated. The stacked gradient is returned,
    multiplied by ``strength``.

    Parameters
    ----------
    x : array, shape (d, *grid)
        Only ``x.shape[1:]`` is used; it fixes the output grid.
    scale : int or float
        Coarsening factor of the underlying noise.
    strength : float
        Multiplier applied to the gradient field.
    seed : int, optional
        When given (including ``0``), the global NumPy RNG is seeded first so the
        result is reproducible. ``None`` draws from the current global RNG state.

    Returns
    -------
    array, shape (len(grid), *grid)
    """
    grid_shape = x.shape[1:]
    shape = [ceil(s / scale) for s in grid_shape]
    # Compare against None so that seed=0 also seeds the RNG.
    if seed is not None:
        np.random.seed(seed)
    scalar_noise = np.random.randn(*shape)
    scalar_noise = zoom(scalar_noise, zoom=scale)
    scalar_noise = center_crop(scalar_noise, shape=grid_shape)
    gradients = np.gradient(scalar_noise)
    # np.gradient returns a bare array (not a list) for 1-D input; wrap it so the
    # result always has a leading component axis.
    if scalar_noise.ndim == 1:
        gradients = [gradients]
    vector_noise = np.stack(gradients)
    return vector_noise * strength


In [None]:
# Generate a grid of bumped tori: one mesh per (scale, bump_height) pair.
radius = 0.25
thickness = 0.10
noise_scale = 16
noise_strength = 10
seed = 50
bump_width = .01
bump_height = 25
resolution = 100  # samples per axis of the [-1, 1]^3 grid

feature_range_bump_height = np.linspace(45, 8, 5, endpoint=False)
feature_range_scale = np.linspace(0.85, 1.15, 5, endpoint=False)
feature_range_angle = np.linspace(-1, 1, 5, endpoint=False)

scaler = MinMaxScaler()
scaler.fit(feature_range_angle.reshape(-1, 1))

scaler_s = MinMaxScaler()
scaler_s.fit(feature_range_scale.reshape(-1, 1))

scaler_b = MinMaxScaler()
scaler_b.fit(feature_range_bump_height.reshape(-1, 1))

# TODO(review): never populated in this cell; confirm whether (s, b) should be stored here.
labels = {}

# The sampling grid, the torus SDF and its marching-cubes mesh do not depend on the
# loop variables, so compute them once instead of once per generated mesh.
coords = np.linspace(-1, 1, resolution)
x = np.stack(np.meshgrid(coords, coords, coords))
sdf = sdf_torus(x, radius, thickness)
base_verts, faces, normals, values = measure.marching_cubes(sdf, level=0)
print(len(base_verts))
print(len(faces))
# marching_cubes reports vertices in voxel-index units, so interpolate on index axes.
voxel_axes = [np.arange(resolution) for _ in range(3)]

for ids, scale in enumerate(feature_range_scale):
    for idb, bump_height in enumerate(feature_range_bump_height):
        filepath = f"/home/jakaria/torus_vis//torus_bump_{ids:04d}_{idb:04d}.ply"

        print(scale)
        print(bump_height)

        # Min-max normalised labels (computed but not used further in this cell).
        s = scaler_s.transform(np.array([scale]).reshape(-1, 1)).item()
        b = scaler_b.transform(np.array([bump_height]).reshape(-1, 1)).item()

        # The last of the five bump-height levels gets a much narrower bump.
        bump_width = 0.001 if idb == 4 else 0.01
        bump_angle = 1

        # Random smooth warp, deliberately unseeded: every mesh gets fresh noise.
        x_warp = gradient_noise(x, noise_scale, noise_strength)

        # Gaussian bump centred on the torus' centre circle at angle pi * bump_angle;
        # its negative gradient is added to the warp field.
        angle = np.pi * bump_angle
        gaussian_center = np.array([np.cos(angle), np.sin(angle), 0]) * radius
        x_dist = np.linalg.norm(x - gaussian_center[:, None, None, None], axis=0)
        x_bump = bump_height * np.exp(-(x_dist ** 2) / bump_width)
        x_warp += -np.stack(np.gradient(x_bump))

        # Sample the displacement field at every vertex.
        x_warp = rearrange(x_warp, 'v h w d -> h w d v')
        vertex_noise = interpn(voxel_axes, x_warp, base_verts)
        # Work on a copy: base_verts is reused by every iteration.
        verts = base_verts.copy()
        verts += vertex_noise

        # Scale about the centroid so the mesh keeps its position.
        original_center = np.mean(verts, axis=0)
        verts = verts * scale
        new_center = np.mean(verts, axis=0)
        displacement_vector = original_center - new_center
        verts += displacement_vector

        igl.write_triangle_mesh(filepath, verts, faces)

In [21]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import DataLoader
import tqdm
import numpy as np
from scipy.stats import entropy
from numpy.linalg import norm
from sklearn.metrics import accuracy_score, mean_squared_error
import os
import os.path as osp
from glob import glob
import openmesh as om

In [22]:
# Inference device: the second GPU when present, otherwise the first GPU, otherwise the CPU.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda', 1)
elif torch.cuda.is_available():
    device = torch.device('cuda', 0)
else:
    device = torch.device('cpu')

# Trained run to evaluate. Other runs used previously (relative to MODEL_ROOT unless absolute):
#   torus/models_contrastive_inhib/253
#   torus/models_guided/30
#   torus/models_only_bvae/10
#   hippocampus/models_guided/44
#   hippocampus/models_contrastive_inhib/172
#   hippocampus/models_attribute/99
#   /home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_only_correlation_loss/models/65
MODEL_ROOT = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw"
model_path = f"{MODEL_ROOT}/torus/models_attribute/23"


def _load_artifact(name, map_location=device):
    """Load ``<model_path>/<name>.pt`` with ``torch.load`` onto ``map_location``.

    NOTE: torch.load unpickles the file -- only use it on trusted checkpoints.
    """
    return torch.load(f"{model_path}/{name}.pt", map_location=map_location)


# Model tensors go straight to the inference device so the spiral indices and
# sampling matrices sit on the same device as the model parameters.
model_state_dict = _load_artifact("model_state_dict")
in_channels = _load_artifact("in_channels")
out_channels = _load_artifact("out_channels")
latent_channels = _load_artifact("latent_channels")
spiral_indices_list = _load_artifact("spiral_indices_list")
up_transform_list = _load_artifact("up_transform_list")
down_transform_list = _load_artifact("down_transform_list")
# Normalisation statistics stay on the CPU so they can be combined with CPU
# tensors; move them to `device` explicitly where needed.
std = _load_artifact("std", map_location="cpu")
mean = _load_artifact("mean", map_location="cpu")
template_face = _load_artifact("faces")

# Create an instance of the model
model = AE(in_channels, out_channels, latent_channels,
           spiral_indices_list, down_transform_list,
           up_transform_list)
model.load_state_dict(model_state_dict)
model.to(device)
# Evaluation mode: BatchNorm layers use their running statistics.
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 32, seq_length=9)
    )
    (1-2): 2 x SpiralEnblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(32, 64, seq_length=9)
    )
    (4): Linear(in_features=6272, out_features=24, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=12, out_features=6272, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(64, 64, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(64, 32, seq_length=9)
    )
    (3-4): 2 x SpiralDeblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (5): SpiralConv(32, 3, seq_length=9)
  )
  (cls_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): BatchN

In [20]:
# Input meshes written by the generation cell. The inference cell writes its
# reconstructions ("torus_bump_reconstructed*.ply") into the same directory, so
# skip those to keep re-runs from feeding reconstructions back into the model.
# Sorting makes the order (and hence the saved arrays) deterministic.
fps = sorted(
    fp for fp in glob('/home/jakaria/torus_vis/*.ply')
    if 'reconstructed' not in fp.split("/")[-1]
)

In [32]:
# Reconstruct every input mesh, write each reconstruction as a PLY file, and
# collect predictions and inputs into .npy arrays.
mesh_predict = []
x_data = []
# Normalisation statistics on the model's device, moved once rather than per mesh.
mean_dev = mean.to(device)
std_dev = std.to(device)
with torch.no_grad():
    for fp in fps:
        mesh = om.read_trimesh(fp)
        x = torch.tensor(mesh.points().astype('float32'))
        x = x.view(1, -1, 3).to(device)
        x = (x - mean_dev) / std_dev

        # A single forward pass; only the reconstruction is used below.
        pred, mu, log_var, re, re2 = model(x)

        # Undo the normalisation on the CPU.
        reshaped_pred = (pred.view(-1, 3).cpu() * std) + mean
        reshaped_x = (x.view(-1, 3).cpu() * std) + mean

        # Output units are the de-normalised coordinates times 300; the PLY files
        # and the collected arrays share these units.
        # NOTE(review): confirm 300 is the intended scale for the .npy arrays as well.
        reshaped_pred = reshaped_pred.numpy() * 300
        reshaped_x = reshaped_x.numpy() * 300
        mesh_predict.append(reshaped_pred)
        x_data.append(reshaped_x)

        # Save the prediction as a PLY file
        subject = fp.split("/")[-1].split(".")[0]
        filepath = f"/home/jakaria/torus_vis/torus_bump_reconstructed{subject}.ply"
        igl.write_triangle_mesh(filepath, reshaped_pred, mesh.face_vertex_indices())

# Save the stacked arrays once, after all meshes are processed.
mesh_predict_np = np.array(mesh_predict)
np.save("/home/jakaria/torus_vis/mesh_predict.npy", mesh_predict_np)

x_data_np = np.array(x_data)
np.save("/home/jakaria/torus_vis/x_data_np.npy", x_data_np)
