In [1]:
import torch
from reconstruction import AE
from datasets import MeshData
from utils import utils, DataLoader, mesh_sampling, sap
import numpy as np
import pyvista as pv
from skimage import measure
from ipywidgets import interact, interactive, fixed, interact_manual, FloatSlider
from IPython.display import display
import meshplot as mp
import os, sys
from math import ceil
from scipy.ndimage import zoom
import open3d as o3d

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Meshplot left an annoying print statement in their code. Using this context manager to suppress it...
class HiddenPrints:
    """Context manager that silences ``print`` output inside a ``with`` block.

    On entry ``sys.stdout`` is replaced by a write handle on ``os.devnull``;
    on exit the stream that was active at entry is put back and the devnull
    handle is closed. Exceptions raised inside the block are not suppressed.
    """

    def __enter__(self):
        """Redirect ``sys.stdout`` to ``os.devnull`` and return this manager."""
        self._original_stdout = sys.stdout
        # Keep our own reference to the sink so that __exit__ closes exactly
        # this handle, even if code inside the block reassigns sys.stdout.
        self._devnull = open(os.devnull, 'w')
        sys.stdout = self._devnull
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore the stream saved on entry, then close the devnull handle."""
        # Restore first so sys.stdout never points at a closed file.
        sys.stdout = self._original_stdout
        self._devnull.close()

In [3]:
# Use the first CUDA device when one is present; otherwise fall back to the
# CPU so the notebook still runs on a machine without a GPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda', 0) if use_cuda else torch.device('cpu')
# With CUDA available keep torch.load's default placement (identical to the
# previous behaviour). Without CUDA, remap every stored tensor to the CPU,
# since tensors saved from a GPU cannot be restored there otherwise.
load_location = None if use_cuda else device

# Set the path to the saved model directory
#model_path = "/home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_only_correlation_loss/models/65"
model_path = "/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/models/8"

# Load the saved model artifacts.
# NOTE(review): torch.load unpickles the files, so only point model_path at
# directories you trust.
model_state_dict = torch.load(f"{model_path}/model_state_dict.pt", map_location=load_location)
in_channels = torch.load(f"{model_path}/in_channels.pt", map_location=load_location)
out_channels = torch.load(f"{model_path}/out_channels.pt", map_location=load_location)
latent_channels = torch.load(f"{model_path}/latent_channels.pt", map_location=load_location)
spiral_indices_list = torch.load(f"{model_path}/spiral_indices_list.pt", map_location=load_location)
up_transform_list = torch.load(f"{model_path}/up_transform_list.pt", map_location=load_location)
down_transform_list = torch.load(f"{model_path}/down_transform_list.pt", map_location=load_location)
# std / mean are applied to the decoder output in the interactive cell to
# undo the normalisation of the vertex coordinates.
std = torch.load(f"{model_path}/std.pt", map_location=load_location)
mean = torch.load(f"{model_path}/mean.pt", map_location=load_location)
template_face = torch.load(f"{model_path}/faces.pt", map_location=load_location)

# Create an instance of the model with the stored hyper-parameters and
# restore its trained weights.
model = AE(in_channels, out_channels, latent_channels,
           spiral_indices_list, down_transform_list,
           up_transform_list)
model.load_state_dict(model_state_dict)
model.to(device)
# Set the model to evaluation mode; as the last expression of the cell this
# also displays the module summary.
model.eval()

AE(
  (en_layers): ModuleList(
    (0): SpiralEnblock(
      (conv): SpiralConv(3, 32, seq_length=9)
    )
    (1-2): 2 x SpiralEnblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (3): SpiralEnblock(
      (conv): SpiralConv(32, 64, seq_length=9)
    )
    (4): Linear(in_features=6272, out_features=24, bias=True)
  )
  (de_layers): ModuleList(
    (0): Linear(in_features=12, out_features=6272, bias=True)
    (1): SpiralDeblock(
      (conv): SpiralConv(64, 64, seq_length=9)
    )
    (2): SpiralDeblock(
      (conv): SpiralConv(64, 32, seq_length=9)
    )
    (3-4): 2 x SpiralDeblock(
      (conv): SpiralConv(32, 32, seq_length=9)
    )
    (5): SpiralConv(32, 3, seq_length=9)
  )
  (cls_sq): Sequential(
    (0): Linear(in_features=1, out_features=8, bias=True)
    (1): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Linear(in_features=8, out_features=8, bias=True)
    (4): BatchN

In [4]:
# Number of latent sliders / decoder inputs. The model summary printed by the
# loading cell shows the decoder's first Linear layer taking 12 features.
LATENT_DIM = 12
TEMPLATE_MESH_PATH = '/home/jakaria/Explaining_Shape_Variability/src/DeepLearning/compute_canada/guided_vae/data/CoMA/raw/torus/template/template.ply'

# The triangle connectivity does not depend on the slider values, so read the
# template mesh once here instead of on every slider move.
template_mesh = o3d.io.read_triangle_mesh(TEMPLATE_MESH_PATH)
template_faces = np.asarray(template_mesh.triangles)

z = torch.zeros(LATENT_DIM)
plot = None


@mp.interact(**{f'z[{i}]': FloatSlider(min=-1, max=1, step=0.2, value=0) for i in range(LATENT_DIM)})
def show(**kwargs):
    """Decode the latent vector chosen with the sliders and draw the mesh.

    Args:
        **kwargs: One float per slider, keyed ``'z[0]'`` ... ``'z[11]'``.

    Side effects:
        Rebinds the module-level ``z`` to the current latent vector (moved to
        ``device``) and ``plot`` to the meshplot viewer. The first call
        creates the viewer; later calls update it in place and re-display
        its renderer.
    """
    global plot
    global z
    z = torch.tensor([kwargs[f'z[{i}]'] for i in range(LATENT_DIM)])
    with torch.no_grad():
        z = z.to(device)
        pred = model.decoder(z)

        # Reshape to one xyz row per vertex and undo the normalisation with
        # the stored std / mean.
        reshaped_pred = (pred.view(-1, 3).cpu() * std) + mean
        # .cpu() is a no-op when the tensor is already on the CPU.
        verts = reshaped_pred.cpu().numpy()

    if plot is None:
        plot = mp.plot(verts, template_faces, return_plot=True)
    else:
        # update_object prints on every call; silence it.
        with HiddenPrints():
            plot.update_object(vertices=verts, faces=template_faces)
        display(plot._renderer)

interactive(children=(FloatSlider(value=0.0, description='z[0]', max=1.0, min=-1.0, step=0.2), FloatSlider(val…

In [None]:
# Re-reads latent_channels from the same file the model-loading cell already
# read (presumably so this cell can be run on its own once model_path is set).
latent_channels = torch.load(f"{model_path}/latent_channels.pt")
# angles.pt sits next to the model files; what it contains is not shown in
# this notebook - TODO document.
angles = torch.load(f"{model_path}/angles.pt")

In [24]:
import torch

# Toy example: four label values stored as a single 1 x 4 row.
y_expanded = torch.tensor([[1.0, 0.95, 0.9, 0.2]])
# Slightly above 0.05, so labels that are 0.05 apart still count as the same
# class (see the printed mask).
threshold = 0.05001

# Subtracting the 4 x 1 transpose from the 1 x 4 row broadcasts to a 4 x 4
# table of pairwise absolute label differences.
abs_diff_matrix = (y_expanded - y_expanded.t()).abs()
# True wherever two labels differ by at most `threshold`.
same_class_mask = torch.le(abs_diff_matrix, threshold)

print(abs_diff_matrix, same_class_mask, sep="\n")


tensor([[0.0000, 0.0500, 0.1000, 0.8000],
        [0.0500, 0.0000, 0.0500, 0.7500],
        [0.1000, 0.0500, 0.0000, 0.7000],
        [0.8000, 0.7500, 0.7000, 0.0000]])
tensor([[ True,  True, False, False],
        [ True,  True,  True, False],
        [False,  True,  True, False],
        [False, False, False,  True]])


In [3]:
# Root directory of the saved models for one training run (absolute path on
# the author's machine).
model_path_root = "/home/jakaria/torus_bump_500_three_scale_binary_bump_variable_noise_fixed_angle/models_classification_regression_only_excitation_inhibition_new/models"
# Load the stored trials; the output of the next cell shows a list of
# FrozenTrial records.
trials = torch.load(f"{model_path_root}/intermediate_trials.pt")

In [4]:
# Echo the whole trial list as the cell output (the repr is very long).
trials

[FrozenTrial(number=0, state=TrialState.COMPLETE, values=[348.6741027832031, 0.06000000000000005, 0.16516362436932752], datetime_start=datetime.datetime(2023, 9, 1, 13, 44, 42, 6978), datetime_complete=datetime.datetime(2023, 9, 1, 13, 51, 3, 867155), params={'epochs': 300, 'batch_size': 24, 'w_cls': 97, 'beta': 0.005369964341189847, 'learning_rate': 0.0003193327361388255, 'learning_rate_decay': 0.74, 'decay_step': 8, 'latent_channels': 12, 'temperature': 141, 'sequence_length': 26, 'dilation': 1, 'out_channel': 8}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, distributions={'epochs': IntDistribution(high=400, log=False, low=100, step=100), 'batch_size': IntDistribution(high=32, log=False, low=4, step=4), 'w_cls': IntDistribution(high=100, log=False, low=1, step=1), 'beta': FloatDistribution(high=0.3, log=True, low=0.001, step=None), 'learning_rate': FloatDistribution(high=0.001, log=True, low=0.0001, step=None), 'learning_rate_decay': FloatDistribution(

In [None]:
# Pasted repr of trial 18, apparently kept for reference. NOTE(review): the
# cell has no execution count, and evaluating it needs FrozenTrial,
# TrialState, datetime and the *Distribution classes in scope; none of them
# is imported in the cells shown here.
FrozenTrial(number=18, state=TrialState.COMPLETE, values=[316.8037414550781, 0.12, 0.6956723767897982], datetime_start=datetime.datetime(2023, 9, 1, 15, 25, 40, 838319), datetime_complete=datetime.datetime(2023, 9, 1, 15, 34, 11, 796451), params={'epochs': 400, 'batch_size': 8, 'w_cls': 8, 'beta': 0.05267669943387056, 'learning_rate': 0.0009280828833567735, 'learning_rate_decay': 0.8999999999999999, 'decay_step': 38, 'latent_channels': 12, 'temperature': 1, 'sequence_length': 33, 'dilation': 1, 'out_channel': 16}, user_attrs={}, system_attrs={'nsga2:generation': 0}, intermediate_values={}, distributions={'epochs': IntDistribution(high=400, log=False, low=100, step=100), 'batch_size': IntDistribution(high=32, log=False, low=4, step=4), 'w_cls': IntDistribution(high=100, log=False, low=1, step=1), 'beta': FloatDistribution(high=0.3, log=True, low=0.001, step=None), 'learning_rate': FloatDistribution(high=0.001, log=True, low=0.0001, step=None), 'learning_rate_decay': FloatDistribution(high=0.99, log=False, low=0.7, step=0.01), 'decay_step': IntDistribution(high=50, log=False, low=1, step=1), 'latent_channels': IntDistribution(high=16, log=False, low=12, step=4), 'temperature': IntDistribution(high=181, log=False, low=1, step=20), 'sequence_length': IntDistribution(high=50, log=False, low=5, step=1), 'dilation': IntDistribution(high=2, log=False, low=1, step=1), 'out_channel': IntDistribution(high=32, log=False, low=8, step=8)}, trial_id=18, value=None),
