In [2]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import utils, DataLoader, mesh_sampling, sap
import numpy as np
import pyvista as pv
from skimage import measure
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from IPython.display import display
import meshplot as mp
import os, sys
from math import ceil
from scipy.ndimage import zoom
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Meshplot left an annoying print statement in their code. Using this context manager to supress it...
class HiddenPrints:
    def __enter__(self):
        self._original_stdout = sys.stdout
        sys.stdout = open(os.devnull, 'w')

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout.close()
        sys.stdout = self._original_stdout

In [3]:
device = torch.device('cuda', 0)
# Set the path to the saved model directory
#model_path = "/home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_only_correlation_loss/models/65"
model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models_corr/7"
# Load the saved model
model_state_dict = torch.load(f"{model_path}/model_state_dict.pt")
in_channels = torch.load(f"{model_path}/in_channels.pt")
out_channels = torch.load(f"{model_path}/out_channels.pt")
latent_channels = torch.load(f"{model_path}/latent_channels.pt")
spiral_indices_list = torch.load(f"{model_path}/spiral_indices_list.pt")
up_transform_list = torch.load(f"{model_path}/up_transform_list.pt")
down_transform_list = torch.load(f"{model_path}/down_transform_list.pt")
std = torch.load(f"{model_path}/std.pt")
mean = torch.load(f"{model_path}/mean.pt")
template_face = torch.load(f"{model_path}/faces.pt")

# Create an instance of the model
model = AE(in_channels, out_channels, latent_channels,
           spiral_indices_list, down_transform_list,
           up_transform_list)
model.load_state_dict(model_state_dict)
model.to(device)
# Set the model to evaluation mode
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 24, seq_length=9)
    )
    (1-2): 2 x SpiralEnblock(
      (conv): SpiralConv(24, 24, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(24, 48, seq_length=9)
    )
    (4): Linear(in_features=4704, out_features=24, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=12, out_features=4704, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(48, 48, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(48, 24, seq_length=9)
    )
    (3-4): 2 x SpiralDeblock(
      (conv): SpiralConv(24, 24, seq_length=9)
    )
    (5): SpiralConv(24, 3, seq_length=9)
  )
  (cls_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): BatchN

In [8]:
z = torch.zeros(12)
plot=None
@mp.interact(**{f'z[{i}]': FloatSlider(min=-1.5, max=1.5, step=0.2, value=0) for i in range(12)})
def show(**kwargs):
    global plot
    global z
    z = torch.tensor([kwargs[f'z[{i}]'] for i in range(12)])
    with torch.no_grad():
        z = z.to(device)
        #print(z)
        pred = model.decoder(z)

        reshaped_pred = (pred.view(-1, 3).cpu() * std) + mean
        reshaped_pred = reshaped_pred.cpu().numpy()
        print(reshaped_pred.shape)

    verts = reshaped_pred
    pcd = o3d.io.read_triangle_mesh('/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/template/template.ply')
    faces = np.asarray(pcd.triangles)

    if plot is None:
        plot = mp.plot(verts, faces, return_plot=True)
    else:
        with HiddenPrints():
            plot.update_object(vertices=verts, faces=faces)
        display(plot._renderer)

interactive(children=(FloatSlider(value=0.0, description='z[0]', max=1.5, min=-1.5, step=0.2), FloatSlider(val…

In [None]:
latent_channels = torch.load(f"{model_path}/latent_channels.pt")
angles = torch.load(f"{model_path}/angles.pt")

In [24]:
import torch

# Sample flattened labels
y_expanded = torch.tensor([[1.0, 0.95, 0.9, 0.2]])
threshold = 0.05001

abs_diff_matrix = torch.abs(y_expanded - y_expanded.t())
same_class_mask = abs_diff_matrix <= threshold

print(abs_diff_matrix)
print(same_class_mask)


tensor([[0.0000, 0.0500, 0.1000, 0.8000],
        [0.0500, 0.0000, 0.0500, 0.7500],
        [0.1000, 0.0500, 0.0000, 0.7000],
        [0.8000, 0.7500, 0.7000, 0.0000]])
tensor([[ True,  True, False, False],
        [ True,  True,  True, False],
        [False,  True,  True, False],
        [False, False, False,  True]])


In [7]:
model_path_root = "/home/jakaria/torus_bump_5000_two_scale_binary_bump_variable_noise_fixed_angle/models_con_inhib_19-11-23_without_lambda1-2/"
trials = torch.load(f"{model_path_root}/intermediate_trials.pt")

In [8]:
trials[196]

FrozenTrial(number=196, state=TrialState.COMPLETE, values=[357.08770751953125, 0.41000000000000003, 0.9592963793411143], datetime_start=datetime.datetime(2023, 11, 14, 6, 35, 52, 595268), datetime_complete=datetime.datetime(2023, 11, 14, 6, 40, 25, 825007), params={'epochs': 400, 'batch_size': 4, 'w_cls': 34, 'beta': 0.21125819643916915, 'learning_rate': 0.00019176245642204008, 'learning_rate_decay': 0.72, 'delta': 0.30000000000000004, 'decay_step': 14, 'latent_channels': 16, 'temperature': 81, 'sequence_length': 30, 'dilation': 2, 'out_channel': 24}, user_attrs={}, system_attrs={'nsga2:generation': 3}, intermediate_values={}, distributions={'epochs': IntDistribution(high=400, log=False, low=100, step=100), 'batch_size': IntDistribution(high=32, log=False, low=4, step=4), 'w_cls': IntDistribution(high=100, log=False, low=1, step=1), 'beta': FloatDistribution(high=0.3, log=True, low=0.001, step=None), 'learning_rate': FloatDistribution(high=0.001, log=True, low=0.0001, step=None), 'lear

In [4]:
model_path_root = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models_con_inhib_lambda_n/"
trials = torch.load(f"{model_path_root}/intermediate_trials.pt")

In [6]:
trials

[FrozenTrial(number=0, state=TrialState.COMPLETE, values=[356.6120910644531, 0.32599999999999996, 0.951632621677345], datetime_start=datetime.datetime(2023, 11, 19, 17, 58, 53, 65275), datetime_complete=datetime.datetime(2023, 11, 19, 18, 3, 47, 101416), params={'lambda1': 0.55}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, distributions={'lambda1': FloatDistribution(high=0.95, log=False, low=0.05, step=0.05)}, trial_id=0, value=None),
 FrozenTrial(number=1, state=TrialState.COMPLETE, values=[357.5674133300781, 0.29000000000000004, 0.9479491074810836], datetime_start=datetime.datetime(2023, 11, 19, 18, 3, 47, 102031), datetime_complete=datetime.datetime(2023, 11, 19, 18, 8, 34, 873927), params={'lambda1': 0.15000000000000002}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, distributions={'lambda1': FloatDistribution(high=0.95, log=False, low=0.05, step=0.05)}, trial_id=1, value=None),
 FrozenTrial(number=2, state=TrialState.