In [None]:
import os, sys

from math import ceil
import pyvista as pv
import numpy as np
import meshplot as mp
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from skimage import measure
from scipy.ndimage import zoom
from scipy.interpolate import interpn
from IPython.display import display
from einops import rearrange
import igl
from tqdm import tqdm
from sklearn.preprocessing import MinMaxScaler
import torch
from scipy import stats
import matplotlib.pyplot as plt
import pandas as pd
import open3d as o3d
from IPython.display import display
import random

In [2]:
# meshplot leaves a stray print statement in its code; this context manager suppresses that output.
class HiddenPrints:
    """Context manager that silences ``print`` output by pointing ``sys.stdout`` at ``os.devnull``.

    On exit the original ``sys.stdout`` is restored and the devnull handle opened in
    ``__enter__`` is closed. Exceptions raised inside the ``with`` body are not suppressed.
    """

    def __enter__(self):
        self._original_stdout = sys.stdout
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close the handle we opened, not whatever sys.stdout is now: code inside the block
        # may have re-pointed sys.stdout, and closing that stream would break it while
        # leaking our devnull handle.
        sys.stdout = self._original_stdout
        self._devnull.close()
        return False

In [3]:
# Dot product on the first dimension of n-dimensional arrays x and y
def dot(x, y):
    """Contract ``x`` and ``y`` over their first axis.

    Computes ``sum_i x[i, ...] * y[i, ...]``; the remaining (trailing) axes are broadcast
    against each other and kept in the result.
    """
    # Sublist form of einsum: label the leading axis 0 in both operands and keep only the
    # ellipsis (all trailing axes) in the output.
    return np.einsum(x, [0, Ellipsis], y, [0, Ellipsis], [Ellipsis])

# Signed distance functions from Inigo Quilez https://iquilezles.org/articles/distfunctions/
# You could implement the smooth minimum operation as well to compose shapes together for more complex situations
def sdf_sphere(x, radius):
    """Signed distance to a sphere of the given radius centred at the origin.

    ``x`` holds coordinates along axis 0; the result has shape ``x.shape[1:]`` and is
    negative inside the sphere, zero on its surface and positive outside.
    """
    dist_from_origin = np.linalg.norm(x, axis=0)
    return dist_from_origin - radius

def sdf_capsule(x, a, b, r):
    """Signed distance from the points in ``x`` to a capsule (segment a-b swept by radius r).

    Port of Inigo Quilez's ``sdCapsule``, vectorised over a grid.

    Parameters
    ----------
    x : array_like, shape (3, ...)
        Query points with the coordinate axis first (e.g. ``np.stack(np.meshgrid(...))``).
    a, b : array_like, shape (3,)
        Endpoints of the capsule's core segment.
    r : float
        Capsule radius.

    Returns
    -------
    ndarray, shape ``x.shape[1:]``
        Negative inside, zero on the surface, positive outside.
    """
    x = np.asarray(x, dtype=float)
    # Reshape endpoints to (3, 1, 1, ...) so they broadcast over the grid axes of x.
    point_shape = (-1,) + (1,) * (x.ndim - 1)
    a = np.asarray(a, dtype=float).reshape(point_shape)
    b = np.asarray(b, dtype=float).reshape(point_shape)
    xa = x - a
    ba = b - a
    ba_len2 = float(np.sum(ba * ba))
    if ba_len2 == 0.0:
        # Degenerate segment (a == b): the capsule is a sphere centred at a.
        h = np.zeros(x.shape[1:])
    else:
        # Parameter of the closest point on the segment, clamped to [0, 1].
        h = np.clip(np.sum(xa * ba, axis=0) / ba_len2, 0., 1.)
    return np.linalg.norm(xa - ba * h, axis=0) - r

def sdf_torus(x, radius, thickness):
    """Signed distance to a torus lying in the xy-plane and centred at the origin.

    Parameters
    ----------
    x : ndarray, shape (3, ...)
        Coordinates along axis 0 (x, y, z).
    radius : float
        Distance from the origin to the centre line of the tube.
    thickness : float
        Radius of the tube.

    Returns an array of shape ``x.shape[1:]``, negative inside the tube.
    """
    # 2-D offset from each point to the ring: (in-plane radial offset, height above the plane).
    radial = np.linalg.norm(x[:2], axis=0) - radius
    tube_offset = np.stack([radial, x[2]])
    return np.linalg.norm(tube_offset, axis=0) - thickness

# Crop an n-dimensional image with a centered cropping region
def center_crop(img, shape):
    """Return the centred window of ``img`` whose extent along each leading axis is ``shape``.

    Axes beyond ``len(shape)`` are left untouched. The window along an axis of length ``n``
    cropped to ``k`` starts at ``n // 2 - k // 2``.
    """
    window = []
    for full_len, crop_len in zip(img.shape, shape):
        lo = full_len // 2 - crop_len // 2
        window.append(slice(lo, lo + crop_len))
    return img[tuple(window)]

# Smooth random displacement field on the grid of `x`: the scaled gradient of upsampled Gaussian noise
def gradient_noise(x, scale, strength, seed=None):
    """Smooth random vector field sampled on the same grid as ``x``.

    A coarse standard-normal noise volume (about one sample per ``scale`` voxels) is
    upsampled with ``scipy.ndimage.zoom``, centre-cropped back to ``x.shape[1:]`` and
    differentiated with ``np.gradient``; the stacked gradient is multiplied by ``strength``.

    Parameters
    ----------
    x : ndarray
        Grid with the component axis first; only ``x.shape[1:]`` is used
        (assumes at least 2 spatial axes so ``np.gradient`` returns one array per axis).
    scale : int or float
        Upsampling factor for the coarse noise (larger -> smoother field).
    strength : float
        Multiplier applied to the gradient field.
    seed : int, optional
        If not None, reseeds NumPy's global RNG before sampling (this modifies the global
        random state). ``seed=0`` is a valid seed.

    Returns
    -------
    ndarray, shape ``(x.ndim - 1, *x.shape[1:])``
    """
    shape = [ceil(s / scale) for s in x.shape[1:]]
    # `is not None` so that seed=0 is honoured (a truthiness test silently ignored it).
    if seed is not None:
        np.random.seed(seed)
    scalar_noise = np.random.randn(*shape)
    scalar_noise = zoom(scalar_noise, zoom=scale)
    # zoom yields roughly ceil(s / scale) * scale >= s voxels per axis; trim to the grid size.
    scalar_noise = center_crop(scalar_noise, shape=x.shape[1:])
    vector_noise = np.stack(np.gradient(scalar_noise))
    return vector_noise * strength


In [None]:
# Generation parameters.
radius = 0.25           # distance from the origin to the centre line of the torus tube
thickness = 0.10        # tube radius
noise_scale = 20        # coarse-noise cell size in voxels (see gradient_noise)
noise_strength = 10
seed = 19               # same noise field for every mesh; pass seed=None for per-mesh random noise
grid_size = 100         # voxels per axis of the sampling grid
torus_dir = "/home/jakaria/torus_vis"

# Swept attributes.
feature_range_bump_height = np.linspace(35, 1, 5, endpoint=True)
# NOTE(review): endpoint=False gives scales 0.85 ... 1.09 (1.15 itself is excluded) — confirm intended.
feature_range_scale = np.linspace(0.85, 1.15, 5, endpoint=False)
feature_range_angle = np.linspace(-1, 1, 5, endpoint=False)

scaler = MinMaxScaler()
scaler.fit(feature_range_angle.reshape(-1, 1))

scaler_s = MinMaxScaler()
scaler_s.fit(feature_range_scale.reshape(-1, 1))

scaler_b = MinMaxScaler()
scaler_b.fit(feature_range_bump_height.reshape(-1, 1))

# NOTE(review): `labels`, `s` and `b` are never used below — presumably meant to record the
# normalised (scale, bump) labels per file; confirm.
labels = {}

os.makedirs(torus_dir, exist_ok=True)

# Everything in this section is independent of (scale, bump_height): compute it once.
coords = np.linspace(-1, 1, grid_size)
x = np.stack(np.meshgrid(coords, coords, coords))
sdf = sdf_torus(x, radius, thickness)
base_verts, faces, normals, values = measure.marching_cubes(sdf, level=0)
# With the default spacing, marching_cubes returns vertices in voxel-index units,
# hence the 0..grid_size-1 interpolation grid.
grid_axes = [np.arange(grid_size) for _ in range(3)]

# The bump is centred on the tube's centre line at angle pi * bump_angle.
bump_angle = 1
angle = np.pi * bump_angle
gaussian_center = np.array([np.cos(angle), np.sin(angle), 0]) * radius
x_dist = np.linalg.norm(x - gaussian_center[:, None, None, None], axis=0)

for ids, scale in enumerate(feature_range_scale):
    for idb, bump_height in enumerate(feature_range_bump_height):
        filename = f"torus_bump_{ids:04d}_{idb:04d}"
        filepath = os.path.join(torus_dir, f"{filename}.ply")

        s = scaler_s.transform(np.array([scale]).reshape(-1, 1)).item()
        b = scaler_b.transform(np.array([bump_height]).reshape(-1, 1)).item()

        # The last (smallest) bump height gets a much narrower bump.
        bump_width = 0.001 if idb == len(feature_range_bump_height) - 1 else 0.01

        # Displacement field = smooth noise + negative gradient of a Gaussian bump.
        x_warp = gradient_noise(x, noise_scale, noise_strength, seed=seed)
        x_bump = bump_height * np.e ** -(1. / bump_width * x_dist ** 2)
        x_warp += -np.stack(np.gradient(x_bump))
        x_warp = rearrange(x_warp, 'v h w d -> h w d v')

        # Copy: the in-place update below must not alter the shared base mesh.
        verts = base_verts.copy()
        verts += interpn(grid_axes, x_warp, verts)

        # Scale about the vertex centroid so the mesh stays where it was.
        original_center = np.mean(verts, axis=0)
        verts = verts * scale
        new_center = np.mean(verts, axis=0)
        verts += original_center - new_center

        igl.write_triangle_mesh(filepath, verts, faces)

In [5]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import DataLoader
import tqdm
import numpy as np
from scipy.stats import entropy
from numpy.linalg import norm
from sklearn.metrics import accuracy_score, mean_squared_error
import os
import os.path as osp
from glob import glob
import openmesh as om

In [6]:
device = torch.device('cuda', 1)
# Directory holding the trained weights and the mesh-operator tensors the model was built with.
# Alternative checkpoints from earlier experiments are kept commented out.
#model_path = "/home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_only_correlation_loss/models/65"
model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models_contrastive_inhib/253"
#model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models_guided/30"# Load the saved model
#model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models_attribute/23"
#model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models_guided/44"
#model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models_contrastive_inhib/172"
#model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models_attribute/99"
#model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models_only_bvae/10"


def _load_artifact(name):
    """Load ``<model_path>/<name>.pt`` (reads the global ``model_path``).

    torch.load unpickles the file, so only point model_path at trusted directories.
    """
    return torch.load(f"{model_path}/{name}.pt")


model_state_dict = _load_artifact("model_state_dict")
in_channels = _load_artifact("in_channels")
out_channels = _load_artifact("out_channels")
latent_channels = _load_artifact("latent_channels")
spiral_indices_list = _load_artifact("spiral_indices_list")
up_transform_list = _load_artifact("up_transform_list")
down_transform_list = _load_artifact("down_transform_list")
std = _load_artifact("std")
mean = _load_artifact("mean")
template_face = _load_artifact("faces")

# Rebuild the architecture from the saved hyper-parameters, then restore the trained weights.
model = AE(in_channels, out_channels, latent_channels,
           spiral_indices_list, down_transform_list, up_transform_list)
model.load_state_dict(model_state_dict)
model.to(device)
# Evaluation mode; as the cell's last expression this also displays the model summary.
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 16, seq_length=9)
    )
    (1-2): 2 x SpiralEnblock(
      (conv): SpiralConv(16, 16, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(16, 32, seq_length=9)
    )
    (4): Linear(in_features=3136, out_features=24, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=12, out_features=3136, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(32, 16, seq_length=9)
    )
    (3-4): 2 x SpiralDeblock(
      (conv): SpiralConv(16, 16, seq_length=9)
    )
    (5): SpiralConv(16, 3, seq_length=9)
  )
  (cls_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): BatchN

In [7]:
# Only the generated inputs: the reconstruction cell writes reconstructed_*.ply into the same
# directory, and a bare *.ply pattern would feed those back into the model on a re-run.
fps = sorted(glob('/home/jakaria/torus_vis/torus_bump_*.ply'))

In [8]:
# List the mesh files that will be passed through the autoencoder.
print(fps)

['/home/jakaria/torus_vis/torus_bump_0000_0000.ply', '/home/jakaria/torus_vis/torus_bump_0000_0001.ply', '/home/jakaria/torus_vis/torus_bump_0000_0002.ply', '/home/jakaria/torus_vis/torus_bump_0000_0003.ply', '/home/jakaria/torus_vis/torus_bump_0000_0004.ply', '/home/jakaria/torus_vis/torus_bump_0001_0000.ply', '/home/jakaria/torus_vis/torus_bump_0001_0001.ply', '/home/jakaria/torus_vis/torus_bump_0001_0002.ply', '/home/jakaria/torus_vis/torus_bump_0001_0003.ply', '/home/jakaria/torus_vis/torus_bump_0001_0004.ply', '/home/jakaria/torus_vis/torus_bump_0002_0000.ply', '/home/jakaria/torus_vis/torus_bump_0002_0001.ply', '/home/jakaria/torus_vis/torus_bump_0002_0002.ply', '/home/jakaria/torus_vis/torus_bump_0002_0003.ply', '/home/jakaria/torus_vis/torus_bump_0002_0004.ply', '/home/jakaria/torus_vis/torus_bump_0003_0000.ply', '/home/jakaria/torus_vis/torus_bump_0003_0001.ply', '/home/jakaria/torus_vis/torus_bump_0003_0002.ply', '/home/jakaria/torus_vis/torus_bump_0003_0003.ply', '/home/jaka

In [25]:
# Reconstruct every generated mesh with the autoencoder; save inputs, predictions and faces.
recon_dir = "/home/jakaria/torus_vis"
# NOTE(review): the purpose of the x300 factor is not stated in this notebook (unit/scale
# conversion?) — confirm. Both the .npy arrays and the PLY files hold the scaled values.
export_scale = 300

mesh_predict = []
x_data = []
faces_all = []
# Normalisation statistics on the model's device, hoisted out of the loop.
mean_dev = mean.to(device)
std_dev = std.to(device)
with torch.no_grad():
    for fp in fps:
        mesh = om.read_trimesh(fp)
        mesh_faces = mesh.face_vertex_indices()

        # Normalise with the training statistics and add a batch dimension: (1, V, 3).
        x = torch.tensor(mesh.points().astype('float32')).view(1, -1, 3).to(device)
        x = (x - mean_dev) / std_dev

        # Single forward pass; the model returns the reconstruction plus extra outputs.
        pred, mu, log_var, re, re2 = model(x)

        # Undo the normalisation on the CPU copies.
        reshaped_pred = ((pred.view(-1, 3).cpu() * std) + mean).numpy()
        reshaped_x = ((x.view(-1, 3).cpu() * std) + mean).numpy()

        # Scale before collecting so the stored arrays are not mutated after being appended.
        reshaped_pred = reshaped_pred * export_scale
        reshaped_x = reshaped_x * export_scale

        mesh_predict.append(reshaped_pred)
        x_data.append(reshaped_x)
        faces_all.append(mesh_faces)

        # Save the prediction as a PLY file next to the inputs.
        subject = os.path.splitext(os.path.basename(fp))[0]
        filepath = os.path.join(recon_dir, f"reconstructed_{subject}.ply")
        igl.write_triangle_mesh(filepath, reshaped_pred, mesh_faces)

mesh_predict_np = np.array(mesh_predict)
np.save(os.path.join(recon_dir, "mesh_predict.npy"), mesh_predict_np)

x_data_np = np.array(x_data)
np.save(os.path.join(recon_dir, "x_data_np.npy"), x_data_np)

faces_all_np = np.array(faces_all)
np.save(os.path.join(recon_dir, "faces_all.npy"), faces_all_np)




In [11]:
# 180-degree rotation about the z axis: identity, with the xy block replaced by R(pi).
rotation_matrix = np.eye(3)
rotation_matrix[:2, :2] = [[np.cos(np.pi), -np.sin(np.pi)],
                           [np.sin(np.pi), np.cos(np.pi)]]

In [26]:
# All-zero latent vector; 12 matches in_features of the decoder's first Linear layer in the model summary.
latent_size = 12
z = torch.zeros(latent_size)

In [29]:
# Decode a latent vector with z[0], z[1] and z[4] overridden and return (verts, faces); nothing is plotted or saved
def visualize_and_save(z0, z1, z_n,
                       template_path='/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/template/template.ply'):
    """Decode a latent vector with entries 0, 1 and 4 overridden and return the mesh.

    The remaining latent entries are copied from the global ``z``, which is read but no
    longer modified, so a call's result does not depend on earlier calls. Despite the name,
    nothing is plotted or saved.

    Parameters
    ----------
    z0, z1 : float
        Values for latent entries 0 and 1.
    z_n : float
        Value for latent entry 4 (NOTE(review): meaning of index 4 is not documented here).
    template_path : str
        Mesh whose triangles are returned as ``faces``.

    Returns
    -------
    verts : ndarray, shape (V, 3)
        Decoded vertices, de-normalised with the global ``mean``/``std``.
    faces : ndarray
        Triangles of the template mesh.
    """
    latent = z.detach().clone()
    latent[0] = z0
    latent[1] = z1
    latent[4] = z_n

    with torch.no_grad():
        pred = model.decoder(latent.to(device))
        verts = ((pred.view(-1, 3).cpu() * std) + mean).numpy()

    faces = np.asarray(o3d.io.read_triangle_mesh(template_path).triangles)
    return verts, faces

In [31]:

# Decode a 5x5 latent traversal: z[1] varies in the outer loop, z[0] in the inner loop, z[4] stays fixed
verts_all, faces_all = [], []
z_n_value = 4.2 # seed 19
z0_values = np.linspace(7, -7, 5)
z1_values = np.linspace(9, -9, 5)
# Row-major traversal: every z0 value for the first z1, then the next z1, and so on.
for z1_value in z1_values:
    for z0_value in z0_values:
        mesh_pair = visualize_and_save(z0_value, z1_value, z_n_value)
        verts, faces = mesh_pair
        verts_all.append(verts)
        faces_all.append(faces)

In [32]:
# Directory for the latent-traversal .npy outputs (the same folder as the generated tori).
output_dir = "/home/jakaria/torus_vis"

In [33]:
# Define file paths for saving vertices and faces
vertices_file = os.path.join(output_dir, "mesh_vertices_walk.npy")
faces_file = os.path.join(output_dir, "mesh_faces_walk.npy")

# One .npy file per quantity: stacked vertices and stacked faces of the traversal meshes.
for target_path, payload in ((vertices_file, verts_all), (faces_file, faces_all)):
    np.save(target_path, payload)

In [None]:
# Function to calculate the magnitude of change for each vertex
def calculate_magnitude_change(verts, initial_verts, verbose=False):
    """Per-vertex Euclidean displacement between two vertex arrays.

    Parameters
    ----------
    verts, initial_verts : array_like
        Vertex arrays with points as rows (norm is taken over axis 1), in the same vertex order.
    verbose : bool, default False
        Print the displacement array (it used to be printed unconditionally).

    Returns
    -------
    ndarray
        One displacement per row.
    """
    diff = np.linalg.norm(np.asarray(verts) - np.asarray(initial_verts), axis=1)
    if verbose:
        print(diff)
    return diff

# Function to map magnitude to colors
# Define colors for specific distance ranges
def map_magnitude_to_colors(magnitude, ranges=None, default_color=(1, 0, 0)):
    """Map each per-vertex displacement magnitude to an RGB colour.

    Always returns exactly one colour per input value, so the result can be used as
    per-vertex colours (e.g. ``mp.plot(..., c=np.asarray(colors))``).

    Parameters
    ----------
    magnitude : iterable of float
        Per-vertex displacement magnitudes.
    ranges : list of (low, high, rgb), optional
        Half-open intervals ``low <= value < high`` checked in order; the first match wins.
        Defaults to ``[(0.0, 0.001, [0, 0, 1])]`` (blue for near-unchanged vertices).
    default_color : sequence of 3 numbers, default red
        Colour for values outside every range (including NaN).

    Returns
    -------
    list of list
        RGB triples, one per input value.
    """
    if ranges is None:
        ranges = [(0.00, 0.001, [0, 0, 1])]  # Blue
    colors = []
    for dist in magnitude:
        for low, high, rgb in ranges:
            if low <= dist < high:
                colors.append(list(rgb))
                break
        else:
            colors.append(list(default_color))
    return colors
# Function to move an object to the center
def move_to_center(verts):
    """Translate a point set so its centroid (mean over axis 0) sits at the origin.

    Returns a new array; the input is not modified.
    """
    points = np.asanyarray(verts)
    return points - points.mean(axis=0)

def calculate_volume_voxelization(mesh, voxel_size=0.001):
    """Occupied-voxel volume estimate: (number of voxels) * voxel_size ** 3.

    NOTE(review): ``VoxelGrid.create_from_triangle_mesh`` marks the voxels intersected by
    the mesh triangles, i.e. a surface shell rather than the enclosed interior, so this
    measures the shell. For a watertight mesh, ``mesh.get_volume()`` gives the enclosed volume.

    Parameters
    ----------
    mesh : open3d.geometry.TriangleMesh
        Mesh to voxelise.
    voxel_size : float, default 0.001
        Voxel edge length in mesh units.

    Returns
    -------
    float
    """
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=voxel_size)
    voxel_count = len(voxel_grid.get_voxels())
    return voxel_count * voxel_grid.voxel_size ** 3

# 180-degree rotation about the z axis; applied to every decoded shape before display.
rotation_matrix = np.eye(3)
rotation_matrix[:2, :2] = [[np.cos(np.pi), -np.sin(np.pi)],
                           [np.sin(np.pi), np.cos(np.pi)]]

# Reference shape: decode the all-zero latent once and put it in the display frame.
z = torch.zeros(12)
with torch.no_grad():
    z = z.to(device)
    pred = model.decoder(z)
    reshaped_pred_initial = ((pred.view(-1, 3).cpu() * std) + mean).cpu().numpy()
verts_initial = np.dot(np.asarray(reshaped_pred_initial), rotation_matrix.T)

# One slider per latent dimension; the first two get readable labels.
plot = None
sliders = dict(
    (f'z[{i}]', FloatSlider(min=-3.0, max=3.0, step=0.5, value=0))
    for i in range(12)
)
sliders['z[0]'].description = 'Disease'
sliders['z[1]'].description = 'Age'

@mp.interact(**sliders)
#@mp.interact(**{f'z[{i}]': FloatSlider(min=-2.5, max=2.5, step=0.4, value=0) for i in range(12)})
def show(**kwargs):
    """Decode the latent vector set by the sliders and draw it, coloured by displacement.

    Slider ``z[i]`` sets latent entry i. The decoded vertices are rotated by
    ``rotation_matrix`` (the same frame as ``verts_initial``). Each vertex is coloured by its
    distance from ``verts_initial``, the shape decoded from the all-zero latent. The first
    call creates the meshplot viewer; later calls update it in place.
    """
    global plot
    global z
    z = torch.tensor([kwargs[f'z[{i}]'] for i in range(12)])
    with torch.no_grad():
        z = z.to(device)
        pred = model.decoder(z)
        reshaped_pred = ((pred.view(-1, 3).cpu() * std) + mean).numpy()

    verts = np.dot(np.asarray(reshaped_pred), rotation_matrix.T)

    # Use the template of the dataset the loaded model belongs to, so faces always match the
    # decoder's vertices. model_path is .../raw/<dataset>/<models_dir>/<epoch>, and the
    # template lives at .../raw/<dataset>/template/template.ply.
    dataset_dir = os.path.dirname(os.path.dirname(model_path))
    template = o3d.io.read_triangle_mesh(os.path.join(dataset_dir, 'template', 'template.ply'))
    faces = np.asarray(template.triangles)

    magnitude = calculate_magnitude_change(verts, verts_initial)
    colors = np.asarray(map_magnitude_to_colors(magnitude))

    if plot is None:
        plot = mp.plot(verts, faces, c=colors, return_plot=True)
    else:
        # meshplot prints on update; suppress that output.
        with HiddenPrints():
            plot.update_object(vertices=verts, faces=faces, colors=colors)
        display(plot._renderer)

In [None]:
# Function to move an object to the center
def move_to_center(verts):
    """Return ``verts`` shifted so that its mean point (over axis 0) is at the origin."""
    centroid = np.average(verts, axis=0)
    return verts - centroid

def calculate_volume_voxelization(mesh, voxel_size=0.1):
    """Occupied-voxel volume estimate: (number of voxels) * voxel_size ** 3.

    NOTE(review): ``VoxelGrid.create_from_triangle_mesh`` marks the voxels intersected by
    the mesh triangles, i.e. a surface shell rather than the enclosed interior, so this
    measures the shell. For a watertight mesh, ``mesh.get_volume()`` gives the enclosed volume.

    Parameters
    ----------
    mesh : open3d.geometry.TriangleMesh
        Mesh to voxelise.
    voxel_size : float, default 0.1
        Voxel edge length in mesh units.

    Returns
    -------
    float
    """
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=voxel_size)
    voxel_count = len(voxel_grid.get_voxels())
    return voxel_count * voxel_grid.voxel_size ** 3

# 180-degree rotation about the z axis, spelled out from its cosine/sine.
_cos, _sin = np.cos(np.pi), np.sin(np.pi)
rotation_matrix = np.array([[_cos, -_sin, 0],
                            [_sin, _cos, 0],
                            [0, 0, 1]])
# Reset the shared latent vector to all zeros (12 latent dimensions).
z = torch.zeros(12)

In [None]:
import numpy as np
import os
import torch
import meshplot as mp
from ipywidgets import interact, FloatSlider
from scipy.spatial import distance
import open3d as o3d
from skimage import measure
from contextlib import contextmanager

In [None]:
# Function to calculate distance between two point clouds (meshes)
def calculate_distance(mesh1, mesh2, symmetric=False):
    """Hausdorff distance between the vertex sets of two meshes (faces are ignored).

    By default this is the *directed* Hausdorff distance from ``mesh1`` to ``mesh2``: the
    largest distance from a vertex of mesh1 to its nearest vertex of mesh2. It is not
    symmetric. Pass ``symmetric=True`` for ``max(d(mesh1, mesh2), d(mesh2, mesh1))``.

    Parameters
    ----------
    mesh1, mesh2 : objects with a ``vertices`` attribute convertible to an (N, D) array
        (e.g. open3d TriangleMesh).
    symmetric : bool, default False

    Returns
    -------
    float
    """
    verts1 = np.asarray(mesh1.vertices, dtype=float)
    verts2 = np.asarray(mesh2.vertices, dtype=float)
    dist = distance.directed_hausdorff(verts1, verts2)[0]
    if symmetric:
        dist = max(dist, distance.directed_hausdorff(verts2, verts1)[0])
    return dist

# Create a directory to save plots and results
# Directory for saved plots and the volumes.txt log (created if missing).
# Note: this rebinds output_dir, which an earlier cell set to the torus_vis folder.
output_dir = "/home/jakaria/save_plots_ms_range"
os.makedirs(output_dir, exist_ok=True)

# Function to save plots
def save_plot(plot, filename):
    """Save a meshplot viewer via its ``save`` method to ``filename`` inside the global ``output_dir``."""
    target = os.path.join(output_dir, filename)
    plot.save(target)

def calculate_volume_voxelization(mesh, voxel_size=0.001):
    """Occupied-voxel volume estimate: (number of voxels) * voxel_size ** 3.

    NOTE(review): ``VoxelGrid.create_from_triangle_mesh`` marks the voxels intersected by
    the mesh triangles, i.e. a surface shell rather than the enclosed interior, so this
    measures the shell. For a watertight mesh, ``mesh.get_volume()`` gives the enclosed volume.

    Parameters
    ----------
    mesh : open3d.geometry.TriangleMesh
        Mesh to voxelise.
    voxel_size : float, default 0.001
        Voxel edge length in mesh units.

    Returns
    -------
    float
    """
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=voxel_size)
    voxel_count = len(voxel_grid.get_voxels())
    return voxel_count * voxel_grid.voxel_size ** 3

# RGB colours (0-1 floats) available for meshplot rendering.
white_color, grey_color = [1.0] * 3, [0.5] * 3

In [None]:
# Decode a latent vector with z[0] and z[1] overridden and return rotated (verts, faces); note this redefines visualize_and_save with a two-argument signature
def visualize_and_save(z0, z1,
                       template_path='/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/template/template.ply'):
    """Decode a latent vector with entries 0 and 1 overridden and return the rotated mesh.

    The remaining latent entries are copied from the global ``z``, which is read but no
    longer modified, so a call's result does not depend on earlier calls. Vertices are
    rotated by the global ``rotation_matrix``. Despite the name, nothing is plotted or saved.

    Parameters
    ----------
    z0, z1 : float
        Values for latent entries 0 and 1.
    template_path : str
        Mesh whose triangles are returned as ``faces``.

    Returns
    -------
    verts : ndarray, shape (V, 3)
        Decoded, de-normalised and rotated vertices.
    faces : ndarray
        Triangles of the template mesh.
    """
    latent = z.detach().clone()
    latent[0] = z0
    latent[1] = z1

    with torch.no_grad():
        pred = model.decoder(latent.to(device))
        verts = ((pred.view(-1, 3).cpu() * std) + mean).numpy()

    verts = np.dot(np.asarray(verts), rotation_matrix.T)
    faces = np.asarray(o3d.io.read_triangle_mesh(template_path).triangles)
    return verts, faces

In [None]:
# For z1 in np.arange(-3, 3, 0.5) and z0 in {-3, 3}: decode, scale, and log each mesh's volume to volumes.txt
verts_all = []
faces_all = []
# NOTE(review): per-axis factors applied before measuring volume (x, y by 60, z by 50);
# their origin is not documented here — confirm.
scalar_values = np.array([60, 60, 50])
# NOTE: z1 runs from -3.0 to 2.5 (np.arange excludes 3.0).
# The log is opened in append mode, so re-running this cell adds another copy of every line.
with open(os.path.join(output_dir, 'volumes.txt'), 'a') as f:
    for z1_value in np.arange(-3, 3, 0.5):
        for z0_value in [-3, 3]:
            verts, faces = visualize_and_save(z0_value, z1_value)
            verts_all.append(verts)
            faces_all.append(faces)

            verts = verts * scalar_values
            mesh = o3d.geometry.TriangleMesh()
            mesh.vertices = o3d.utility.Vector3dVector(verts)
            mesh.triangles = o3d.utility.Vector3iVector(faces)

            orientable = mesh.is_orientable()
            watertight = mesh.is_watertight()
            # Open3D documents get_volume() as requiring a watertight, orientable mesh;
            # record NaN instead of aborting the whole sweep on a bad mesh.
            volume = mesh.get_volume() if (watertight and orientable) else float('nan')
            print(f'z0={z0_value:.1f} z1={z1_value:.1f} n_verts={len(verts)} '
                  f'orientable={orientable} watertight={watertight} volume={volume}')

            f.write(f'z0_{z0_value:.1f}_z1_{z1_value:.1f}: {volume:.1f}\n')