In [1]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import utils, DataLoader, mesh_sampling, sap
import numpy as np
import pyvista as pv
from skimage import measure
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from IPython.display import display
import meshplot as mp
import os, sys
from math import ceil
from scipy.ndimage import zoom
import open3d as o3d
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Meshplot left an annoying print statement in their code. Using this context manager to supress it...
class HiddenPrints:
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout

In [3]:
device = torch.device('cuda', 1)
# Set the path to the saved model directory
#model_path = "/home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_contrastive_loss_only/models/199"
model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models_attribute_age_limited_tle"
# Load the saved model
model_state_dict = torch.load(f"{model_path}/model_state_dict.pt")
in_channels = torch.load(f"{model_path}/in_channels.pt")
out_channels = torch.load(f"{model_path}/out_channels.pt")
latent_channels = torch.load(f"{model_path}/latent_channels.pt")
spiral_indices_list = torch.load(f"{model_path}/spiral_indices_list.pt")
up_transform_list = torch.load(f"{model_path}/up_transform_list.pt")
down_transform_list = torch.load(f"{model_path}/down_transform_list.pt")
std = torch.load(f"{model_path}/std.pt")
mean = torch.load(f"{model_path}/mean.pt")
template_face = torch.load(f"{model_path}/faces.pt")

# Create an instance of the model
model = AE(in_channels, out_channels, latent_channels,
           spiral_indices_list, down_transform_list,
           up_transform_list)
model.load_state_dict(model_state_dict)
model.to(device)
# Set the model to evaluation mode
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 16, seq_length=9)
    )
    (1-2): 2 x SpiralEnblock(
      (conv): SpiralConv(16, 16, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(16, 32, seq_length=9)
    )
    (4): Linear(in_features=5696, out_features=24, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=12, out_features=5696, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(32, 16, seq_length=9)
    )
    (3-4): 2 x SpiralDeblock(
      (conv): SpiralConv(16, 16, seq_length=9)
    )
    (5): SpiralConv(16, 3, seq_length=9)
  )
  (cls_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): BatchN

In [11]:
# Function to calculate the magnitude of change for each vertex
def calculate_magnitude_change(verts, initial_verts):
    diff = np.linalg.norm(verts - initial_verts, axis=1)
    print(diff)
    return diff

# Function to map magnitude to colors
# Define colors for specific distance ranges
def map_magnitude_to_colors(magnitude):
    colors = []
    for dist in magnitude:
        if 0.00 <= dist < 0.001:
            colors.append([0, 0, 1])         # Blue
        elif 0.001 <= dist < 0.002:
            colors.append([0.4, 0.4, 1])     # Light Blue
        elif 0.002 <= dist < 0.005:
            colors.append([0.4, 0.4, 1])         # White
        elif 0.005 <= dist < 0.01:
            colors.append([0.4, 0.4, 1])       # Light Yellow
        else:
            colors.append([1, 1, 0])         # Yellow
    return colors
# Function to move an object to the center
def move_to_center(verts):
    centroid = np.mean(verts, axis=0)
    return verts - centroid

def calculate_volume_voxelization(mesh):
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=0.001)
    voxel_count = len(voxel_grid.get_voxels())
    voxel_volume = voxel_grid.voxel_size ** 3
    volume = voxel_count * voxel_volume
    return volume

rotation_matrix = np.array([[np.cos(np.pi), -np.sin(np.pi), 0],
                            [np.sin(np.pi), np.cos(np.pi), 0],
                            [0, 0, 1]]) #z axic
z = torch.zeros(12)
with torch.no_grad():
    z = z.to(device)
    #print(z)
    pred = model.decoder(z)

    reshaped_pred_initial = (pred.view(-1, 3).cpu() * std) + mean
    reshaped_pred_initial = reshaped_pred_initial.cpu().numpy()
    #print(reshaped_pred.shape)

verts_initial = reshaped_pred_initial
#verts = o3d.utility.Vector3dVector(np.dot(np.asarray(verts), rotation_matrix.T))
verts_initial = np.dot(np.asarray(verts_initial), rotation_matrix.T)

plot=None
sliders = {f'z[{i}]': FloatSlider(min=-3.0, max=3.0, step=0.5, value=0) for i in range(12)}
sliders['z[0]'].description = 'Disease'
sliders['z[1]'].description = 'Age'

@mp.interact(**sliders)
#@mp.interact(**{f'z[{i}]': FloatSlider(min=-2.5, max=2.5, step=0.4, value=0) for i in range(12)})
def show(**kwargs):
    global plot
    global z
    z = torch.tensor([kwargs[f'z[{i}]'] for i in range(12)])
    with torch.no_grad():
        z = z.to(device)
        #print(z)
        pred = model.decoder(z)

        reshaped_pred = (pred.view(-1, 3).cpu() * std) + mean
        reshaped_pred = reshaped_pred.cpu().numpy()
        #print(reshaped_pred.shape)

    verts = reshaped_pred
    #verts = o3d.utility.Vector3dVector(np.dot(np.asarray(verts), rotation_matrix.T))
    verts = np.dot(np.asarray(verts), rotation_matrix.T)

    pcd = o3d.io.read_triangle_mesh('/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/template/template.ply')
    faces = np.asarray(pcd.triangles)
    initial_verts = np.asarray(pcd.vertices)
    #print(verts)
    #print(faces)
    '''
    # Define a threshold for separating the objects based on x-values
    x_threshold = 0.0

    # Separate vertices and faces for the first object (x < x_threshold)
    verts_object1 = verts[verts[:, 0] < x_threshold]
    faces_object1 = [face for face in faces if (verts[face, 0] < x_threshold).all()]

    # Separate vertices and faces for the second object (x >= x_threshold)
    verts_object2 = verts[verts[:, 0] >= x_threshold]
    faces_object2 = [face for face in faces if (verts[face, 0] >= x_threshold).all()]

    # Convert lists to NumPy arrays
    verts_object1 = np.array(verts_object1)
    faces_object1 = np.array(faces_object1)
    verts_object2 = np.array(verts_object2)
    faces_object2 = np.array(faces_object2)


    print(verts_object1.shape)
    print(faces_object1.shape)

    verts_object1 = move_to_center(verts_object1)

    mesh = o3d.geometry.TriangleMesh()
    mesh.vertices = o3d.utility.Vector3dVector(verts_object2)
    mesh.triangles = o3d.utility.Vector3iVector(faces_object2)
    volume = calculate_volume_voxelization(mesh)
    print(volume)

    #o3d.visualization.draw_geometries([verts_object1])
    #o3d.visualization.draw_plotly([verts_object1])
    
    #with np.printoptions(threshold=np.inf):
        #print(verts_object1)
    
    '''

    # Calculate magnitude of change
    magnitude = calculate_magnitude_change(verts, verts_initial)

    # Map magnitude to colors
    colors = map_magnitude_to_colors(magnitude)
    colors = np.asarray(colors)


    white_color = [1.0, 1.0, 1.0]
    grey_color = [0.5, 0.5, 0.5]

    if plot is None:
        #plot = mp.plot(verts_object1)
        plot = mp.plot(verts, faces, c=colors, return_plot=True)
    else:
        with HiddenPrints():
            #plot.update_object(vertices=verts, faces=faces)
            plot.update_object(vertices=verts, faces=faces, colors=colors)
        display(plot._renderer)


interactive(children=(FloatSlider(value=0.0, description='Disease', max=3.0, min=-3.0, step=0.5), FloatSlider(…

In [6]:
# Function to move an object to the center
def move_to_center(verts):
    centroid = np.mean(verts, axis=0)
    return verts - centroid

def calculate_volume_voxelization(mesh):
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=0.1)
    voxel_count = len(voxel_grid.get_voxels())
    voxel_volume = voxel_grid.voxel_size ** 3
    volume = voxel_count * voxel_volume
    return volume

rotation_matrix = np.array([[np.cos(np.pi), -np.sin(np.pi), 0],
                            [np.sin(np.pi), np.cos(np.pi), 0],
                            [0, 0, 1]]) #z axic
z = torch.zeros(12)

In [7]:
latent_channels = torch.load(f"{model_path}/latent_channels.pt")
angles = torch.load(f"{model_path}/angles.pt")

FileNotFoundError: [Errno 2] No such file or directory: '/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models_attribute_age_limited_tle/angles.pt'

In [5]:
import torch

# Sample flattened labels
y_expanded = torch.tensor([[1.0, 0.95, 0.9, 0.2]])
threshold = 0.05001

abs_diff_matrix = torch.abs(y_expanded - y_expanded.t())
same_class_mask = abs_diff_matrix <= threshold

print(abs_diff_matrix)
print(same_class_mask)


tensor([[0.0000, 0.0500, 0.1000, 0.8000],
        [0.0500, 0.0000, 0.0500, 0.7500],
        [0.1000, 0.0500, 0.0000, 0.7000],
        [0.8000, 0.7500, 0.7000, 0.0000]])
tensor([[ True,  True, False, False],
        [ True,  True,  True, False],
        [False,  True,  True, False],
        [False, False, False,  True]])


In [None]:
model_path_root = "//home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/models_contrastive_inhib"
trials = torch.load(f"{model_path_root}/intermediate_trials.pt")
trials

In [7]:
import numpy as np
import os
import torch
import meshplot as mp
from ipywidgets import interact, FloatSlider
from scipy.spatial import distance
import open3d as o3d
from skimage import measure
from contextlib import contextmanager


In [8]:
# Function to calculate distance between two point clouds (meshes)
def calculate_distance(mesh1, mesh2):
    return distance.directed_hausdorff(mesh1.vertices, mesh2.vertices)[0]

# Create a directory to save plots and results
output_dir = "/home/jakaria/save_plots_ms_range"
os.makedirs(output_dir, exist_ok=True)

# Function to save plots
def save_plot(plot, filename):
    plot.save(os.path.join(output_dir, filename))

def calculate_volume_voxelization(mesh):
    voxel_grid = o3d.geometry.VoxelGrid.create_from_triangle_mesh(mesh, voxel_size=0.001)
    voxel_count = len(voxel_grid.get_voxels())
    voxel_volume = voxel_grid.voxel_size ** 3
    volume = voxel_count * voxel_volume
    return volume

white_color = [1.0, 1.0, 1.0]
grey_color = [0.5, 0.5, 0.5]

In [9]:
# Function to visualize, save plots, and calculate distances and volumes
def visualize_and_save(z0, z1):
    global plot
    global z

    z[0] = z0
    z[1] = z1

    with torch.no_grad():
        z = z.to(device)
        pred = model.decoder(z)

        reshaped_pred = (pred.view(-1, 3).cpu() * std) + mean
        verts = reshaped_pred.cpu().numpy()

    verts = np.dot(np.asarray(verts), rotation_matrix.T)

    pcd = o3d.io.read_triangle_mesh('/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/template/template.ply')
    faces = np.asarray(pcd.triangles)

    white_color = [1.0, 1.0, 1.0]
    grey_color = [0.5, 0.5, 0.5]
     
    #mp.subplot(verts, faces, c=np.array(white_color), s=[2, 2, 0])
    #plot = mp.plot(verts, faces, c=np.array(white_color), return_plot=True)
    #save_plot(plot, f'z0_{z0:.1f}_z1_{z1:.1f}.png')
    #print(verts.shape, faces.shape)
    return verts, faces

In [12]:
# Sample and save plots for different z[0] values while fixing z[1]
verts_all = []
faces_all = []
for z1_value in np.arange(-3, 3, 0.5):
        for z0_value in [-1.5, 1.5]:
             verts, faces = visualize_and_save(z0_value, z1_value)
             verts_all.append(verts)
             faces_all.append(faces)
             scalar_values = np.array([60, 60, 50])
             verts = verts * scalar_values
             print(max(verts[:, 0]), min(verts[:, 0]))
             print(max(verts[:, 1]), min(verts[:, 1]))
             print(max(verts[:, 2]), min(verts[:, 2]))
             mesh = o3d.geometry.TriangleMesh()
             mesh.vertices = o3d.utility.Vector3dVector(verts)
             mesh.triangles = o3d.utility.Vector3iVector(faces)

             print(np.asanyarray(mesh.vertices).shape)
             print(mesh.is_orientable())
             print(mesh.is_watertight())
             
             #mesh.orient_triangles()
             #volume = calculate_volume_voxelization(mesh)
             print(mesh.get_volume())
             volume = mesh.get_volume()
             #print(mesh.get_max_bound(), mesh.get_min_bound())
             print(volume)
          
             #Save volumes to a text file 
             with open(os.path.join(output_dir, 'volumes.txt'), 'a') as f:
                  f.write(f'z0_{z0_value:.1f}_z1_{z1_value:.1f}: {volume:.1f}\n')

29.769492745399475 -30.13582706451416
20.611996650695804 -20.85673034191132
4.492974653840065 -4.3707940727472305
(6378, 3)
True
True
4946.27348786492
4946.27348786492
29.61395502090454 -30.13048768043518
20.264525413513187 -21.088930964469906
4.90114688873291 -4.380163922905922
(6378, 3)
True
True
5029.688542521696
5029.688542521696
29.764416217803955 -30.14093041419983
20.53082942962647 -20.774751305580143
4.456835985183716 -4.362678527832031
(6378, 3)
True
True
4895.500491917255
4895.500491917255
29.604247212409973 -30.145111083984375
20.146316885948185 -21.02820217609406
4.8659637570381165 -4.3561723083257675
(6378, 3)
True
True
4973.7629018845255
4973.7629018845255
29.759291410446167 -30.146069526672363
20.44864654541016 -20.694925189018253
4.4193219393491745 -4.356484487652779
(6378, 3)
True
True
4844.807899147368
4844.807899147368
29.59429621696472 -30.16010284423828
20.029715895652775 -20.97084760665894
4.828977212309837 -4.333098977804184
(6378, 3)
True
True
4917.975151073241


In [None]:
# Sample and save plots for different z0 and z1 values
volume_z0_m1 = []
volume_z0_p1 = []
volume_differences = []

# Define the range and increment for z1 values
z1_start = -1.2
z1_end = 1.3
z1_increment = 0.02

# Define the ranges for which you want to calculate average volumes
z1_ranges = [
    (-1.2, -0.8),
    (-0.8, -0.4),
    (-0.4, 0.0),
    (0.0, 0.4),
    (0.4, 0.8),
    (0.8, 1.2),
]

for z0_value in [-1.5, 1.5]:
    for z1_range in z1_ranges:
        z1_range_start, z1_range_end = z1_range
        volumes_for_z1_range = []

        for z1_value in np.arange(z1_range_start, z1_range_end + z1_increment, z1_increment):
            verts, faces = visualize_and_save(z0_value, z1_value)
            #verts_all.append(verts)
            #faces_all.append(faces)
            mesh = o3d.geometry.TriangleMesh()
            mesh.vertices = o3d.utility.Vector3dVector(verts)
            mesh.triangles = o3d.utility.Vector3iVector(faces)
            volume = mesh.get_volume()*316800
            volumes_for_z1_range.append(volume)

        # Calculate and print the average volume for the current z1 range
        average_volume = sum(volumes_for_z1_range) / len(volumes_for_z1_range)

        if z0_value == -1.5:
            volume_z0_m1.append(average_volume)

        if z0_value == 1.5:
            volume_z0_p1.append(average_volume)
        
        # Print and save the average volume without detailed range information
        print(f'z0_{z0_value:.1f}_z1_range_{z1_range_start:.2f}_{z1_range_end:.2f} (Average Volume): {average_volume:.1f}')

        # Save average volumes to a text file
        with open(os.path.join(output_dir, 'average_volumes.txt'), 'a') as f:
            f.write(f'z0_{z0_value:.1f}_z1_range_{z1_range_start:.2f}_{z1_range_end:.2f} (Average Volume): {average_volume:.1f}\n')

# Calculate and store individual volume differences for each range
for i in range(len(z1_ranges)):
    volume_difference = volume_z0_p1[i] - volume_z0_m1[i]
    volume_differences.append(volume_difference)
    print(volume_differences)


In [None]:
volume_differences = []
for i in range(len(z1_ranges)):
    volume_difference = -(volume_z0_m1[i] - volume_z0_p1[i])
    volume_differences.append(volume_difference)
print(volume_differences)

In [None]:
import matplotlib.pyplot as plt
# Create custom x-axis labels
#TLE 18 to 73
age_ranges = [
    (18, 26),
    (27, 35),
    (36, 44),
    (45, 53),
    (54, 62),
    (63, 71),
]
# Create custom x-axis labels
x_axis_labels = [f'{start}-{end}' for start, end in age_ranges]
volume_differences
plt.figure(figsize=(10, 6))
plt.plot(volume_differences, marker='o', linestyle='-', color='b')
plt.title('Volume Differences vs. Age Ranges')
plt.xlabel('Age Range')
plt.ylabel('Volume Differences ($mm^3$)')
plt.xticks(range(len(z1_ranges)), x_axis_labels, rotation=45)
plt.grid(True)
plt.tight_layout()

# Show the plot
plt.show()

In [None]:
for i in range(2):
    mp.plot(verts_all[i], faces_all[i], c=np.array(white_color))

In [None]:
# Define file paths for saving vertices and faces
vertices_file = os.path.join(output_dir,"mesh_vertices.npy")
faces_file = os.path.join(output_dir,"mesh_faces.npy")


# Save vertices and faces to separate .npy files
np.save(vertices_file, verts_all)
np.save(faces_file, faces_all)

In [None]:
# Calculate the distances between corresponding vertices
distances = np.linalg.norm(verts_all[1] - verts_all[0], axis=1)

# Create Open3D TriangleMesh objects
mesh1 = o3d.geometry.TriangleMesh()
mesh1.vertices = o3d.utility.Vector3dVector(verts_all[1])
mesh1.compute_vertex_normals()

mesh2 = o3d.geometry.TriangleMesh()
mesh2.vertices = o3d.utility.Vector3dVector(verts_all[0])
mesh2.compute_vertex_normals()

# Create a color mapping based on distances
# You can adjust the color thresholds as needed
color_map = o3d.visualization.draw_plotly([mesh1, mesh2])

# Normalize distances to the range [0, 1] for colormap mapping
normalized_distances = (distances - distances.min()) / (distances.max() - distances.min())

# Map distances to colors using the colormap
colors = np.asarray(color_map(normalized_distances))[:, :3]  # Extract RGB values

# Set the base color of mesh2 to white (or any other base color)
mesh2.paint_uniform_color([1, 1, 1])

# Assign the calculated colors to the faces of mesh2
mesh2.vertex_colors = o3d.utility.Vector3dVector(colors)

# Visualize both meshes
o3d.visualization.draw_plotly([mesh1, mesh2])

In [None]:

# Calculate distance between meshes and create color-coded plots
z[1] = 0.2  # Set a specific value for z[1]
distances = np.zeros((len(np.arange(-1, 1.2, 0.2)), len([-1, 1])))
for i, z0_value in enumerate([-1, 1]):
    for j, z1_value in enumerate(np.arange(-1, 1.2, 0.2)):
        visualize_and_save(z0_value, z1_value)
        mesh1 = o3d.geometry.TriangleMesh.create_from_points(verts)
        mesh2 = o3d.geometry.TriangleMesh.create_from_points(verts)
        distances[j, i] = calculate_distance(mesh1, mesh2)

# Save distances as a color-coded plot
import matplotlib.pyplot as plt
plt.imshow(distances, cmap='viridis', origin='lower', extent=[-1, 1, -1, 1])
plt.colorbar(label='Distance')
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.title('Distance Heatmap')
plt.savefig(os.path.join(output_dir, 'distance_heatmap.png'))


In [None]:
np.interp(0.77, [0,1], [1,-1])

In [None]:
np.interp(0.15, [0,1], [1,-1])

In [None]:
test = torch.load('/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/hippocampus/labels.pt')
print(test)