# Latent Mesh Explorer

Interactively decode MeshDecoder latent vectors, tweak a single coefficient, and inspect the resulting mesh wireframe.


In [1]:
from pathlib import Path
import sys
from typing import Dict

import numpy as np
import torch
from torch.nn import Embedding
# from torch.serialization import add_safe_globals

import plotly.graph_objects as go
from plotly.subplots import make_subplots

import ipywidgets as widgets
from IPython.display import display
from pytorch3d.structures import Meshes
# import subdivide meshes 
from pytorch3d.ops import subdivide_meshes

PROJECT_ROOT = Path('/home/ralbe/DALS/mesh_autodecoder').resolve()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from model.mesh_decoder import MeshDecoder

CHECKPOINT_PATH = Path('/home/ralbe/DALS/mesh_autodecoder/models/MeshDecoderTrainer_2025-11-06_12-00-26.ckpt')
LATENT_DIR = Path('/home/ralbe/DALS/mesh_autodecoder/inference_results/meshes_MeshDecoderTrainer_2025-11-06_12-00-26/latents')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# add_safe_globals([Embedding])


In [2]:
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
hparams: Dict = checkpoint['hparams']

decoder = MeshDecoder(
    latent_features=hparams['latent_features'],
    steps=hparams['steps'],
    hidden_features=hparams['hidden_features'],
    subdivide=hparams['subdivide'],
    mode=hparams['decoder_mode'],
    norm=hparams['normalization'][0],
)
decoder.load_state_dict(checkpoint['decoder_state_dict'])
decoder.to(DEVICE)
decoder.eval()

template: Meshes = checkpoint['template'].to(DEVICE)
subdiv = subdivide_meshes.SubdivideMeshes()

template = subdiv(template).to(DEVICE)
latent_dim = hparams['latent_features']


print(f'Device: {DEVICE}')
print(f'Latent dimension: {latent_dim}')


Device: cpu
Latent dimension: 128


  self._edges_packed = torch.stack([u // V, u % V], dim=1)


In [3]:
latent_cache: Dict[Path, torch.Tensor] = {}

def load_latent_tensor(path: Path) -> torch.Tensor:
    if path not in latent_cache:
        tensor = torch.load(path, map_location='cpu')
        if isinstance(tensor, dict):
            for key in ('latent', 'latent_vector', 'latent_vectors'):
                if key in tensor:
                    tensor = tensor[key]
                    break
            else:
                raise KeyError(f'No latent vector found in {path}')
        tensor = torch.as_tensor(tensor).float()
        if tensor.ndim == 2:
            tensor = tensor.squeeze(0)
        latent_cache[path] = tensor.clone()
    return latent_cache[path].clone()

def decode_latent(latent_1d: torch.Tensor) -> Meshes:
    latent_batch = latent_1d.unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        decoded = decoder(template.clone(), latent_batch)[-1]
    return decoded.cpu()

def mesh_bounds(mesh: Meshes) -> np.ndarray:
    verts = mesh.verts_packed()
    return verts.min(dim=0).values.cpu().numpy(), verts.max(dim=0).values.cpu().numpy()

def mesh_to_wireframe_trace(mesh: Meshes, color: str, name: str) -> go.Scatter3d:
    mesh_cpu = mesh.cpu()
    verts = mesh_cpu.verts_packed().numpy()
    edges = mesh_cpu.edges_packed().numpy()
    coords = np.empty((3 * len(edges), 3), dtype=np.float32)
    coords[0::3] = verts[edges[:, 0]]
    coords[1::3] = verts[edges[:, 1]]
    coords[2::3] = np.nan
    return go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='lines',
        name=name,
        line=dict(color=color, width=2),
        hoverinfo='skip',
    )

def make_comparison_figure(
    original: Meshes,
    modified: Meshes,
    latent_index: int,
    new_value: float,
    original_value: float,
) -> go.Figure:
    o_min, o_max = mesh_bounds(original)
    m_min, m_max = mesh_bounds(modified)

    def axis_ranges(mins: np.ndarray, maxs: np.ndarray):
        center = 0.5 * (mins + maxs)
        span = max(maxs - mins)
        half = max(span / 2.0, 1e-3)
        return [[center[i] - half, center[i] + half] for i in range(3)]

    original_ranges = axis_ranges(o_min, o_max)
    combined_ranges = axis_ranges(np.minimum(o_min, m_min), np.maximum(o_max, m_max))

    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=(
            f'Original (index {latent_index} = {original_value:+.4f})',
            f'Overlay (index {latent_index} → {new_value:+.4f})',
        ),
    )

    fig.add_trace(mesh_to_wireframe_trace(original, '#636EFA', 'Original'), row=1, col=1)

    baseline_trace = mesh_to_wireframe_trace(original, '#9CA3FF', 'Original baseline')
    baseline_trace.line['width'] = 1
    fig.add_trace(baseline_trace, row=1, col=2)
    fig.add_trace(mesh_to_wireframe_trace(modified, '#EF553B', 'Modified'), row=1, col=2)

    scene_original = dict(
        xaxis=dict(visible=False, range=combined_ranges[0]),
        yaxis=dict(visible=False, range=combined_ranges[1]),
        zaxis=dict(visible=False, range=combined_ranges[2]),
        aspectmode='cube',
    )
    scene_overlay = dict(
        xaxis=dict(visible=False, range=original_ranges[0]),
        yaxis=dict(visible=False, range=original_ranges[1]),
        zaxis=dict(visible=False, range=original_ranges[2]),
        aspectmode='cube',
    )
    fig.update_layout(
        height=600,
        width=1100,
        margin=dict(l=0, r=0, t=60, b=0),
        scene=scene_original,
        scene2=scene_overlay,
        showlegend=False,
    )
    return fig


In [4]:
latent_files = sorted(LATENT_DIR.glob('*_latent.pt'))
if not latent_files:
    raise FileNotFoundError(f'No latent tensors found in {LATENT_DIR}')

latent_dropdown = widgets.Dropdown(
    options=[(path.name, str(path)) for path in latent_files],
    value=str(latent_files[0]),
    description='Latent file:',
    layout=widgets.Layout(width='60%'),
)

index_text = widgets.BoundedIntText(
    min=0,
    max=latent_dim - 1,
    value=0,
    step=1,
    description='Index:',
    layout=widgets.Layout(width='120px'),
    style={'description_width': 'initial'},
)

value_slider = widgets.FloatSlider(
    min=-0.4,
    max=0.4,
    step=0.01,
    value=0.0,
    description='New value:',
    continuous_update=False,
)

info_label = widgets.HTML()


def render(latent_path: str, latent_index: int, slider_value: float):
    latent_tensor = load_latent_tensor(Path(latent_path))
    # Clamp latent_index to allowed range just in case
    latent_index = int(np.clip(latent_index, 0, latent_tensor.shape[0] - 1))
    original_value = float(latent_tensor[latent_index])
    info_label.value = f'<b>Original value:</b> {original_value:+.4f}'

    modified_tensor = latent_tensor.clone()
    modified_tensor[latent_index] = slider_value

    original_mesh = decode_latent(latent_tensor)
    modified_mesh = decode_latent(modified_tensor)

    fig = make_comparison_figure(
        original_mesh,
        modified_mesh,
        latent_index=latent_index,
        new_value=slider_value,
        original_value=original_value,
    )
    display(fig)


def handle_latent_change(change):
    latent_tensor = load_latent_tensor(Path(change['new']))
    index_text.max = latent_tensor.shape[0] - 1
    index_text.value = 0
    base_value = float(latent_tensor[0])
    clipped = float(np.clip(base_value, value_slider.min, value_slider.max))
    if abs(value_slider.value - clipped) > 1e-6:
        value_slider.value = clipped
    else:
        info_label.value = f'<b>Original value:</b> {base_value:+.4f}'


def handle_index_change(change):
    latent_tensor = load_latent_tensor(Path(latent_dropdown.value))
    new_index = change['new']
    # Clamp just in case
    if new_index < 0:
        new_index = 0
    elif new_index > latent_tensor.shape[0] - 1:
        new_index = latent_tensor.shape[0] - 1
    base_value = float(latent_tensor[new_index])
    info_label.value = f'<b>Original value:</b> {base_value:+.4f}'
    clipped = float(np.clip(base_value, value_slider.min, value_slider.max))
    if abs(value_slider.value - clipped) > 1e-6:
        value_slider.value = clipped

latent_dropdown.observe(handle_latent_change, names='value')
index_text.observe(handle_index_change, names='value')

interactive_plot = widgets.interactive_output(
    render,
    {
        'latent_path': latent_dropdown,
        'latent_index': index_text,
        'slider_value': value_slider,
    },
)

controls = widgets.VBox([
    latent_dropdown,
    widgets.HBox([index_text, value_slider]),
    info_label,
])

handle_latent_change({'new': latent_dropdown.value})
display(controls, interactive_plot)


VBox(children=(Dropdown(description='Latent file:', layout=Layout(width='60%'), options=(('cirrhotic_115_testi…

Output()

In [5]:
from pathlib import Path
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from analysis.latent_metrics import (
    baseline_latent_vector,
    compute_deformation_metrics,
    decode_latent,
    load_latent_tensor,
)

baseline_latent = baseline_latent_vector(checkpoint).to(DEVICE)
baseline_mesh = decode_latent(decoder, template, baseline_latent, DEVICE)
baseline_verts = baseline_mesh.verts_packed().cpu().numpy()
baseline_faces = baseline_mesh.faces_packed().cpu().numpy()

selected_latent_path = Path(latent_dropdown.value)
selected_latent = load_latent_tensor(selected_latent_path).to(DEVICE)
target_mesh = decode_latent(decoder, template, selected_latent, DEVICE)
deformation_metrics = compute_deformation_metrics(baseline_mesh, target_mesh)

current_latent_path = selected_latent_path
current_target_mesh = target_mesh
current_deformation = deformation_metrics

verts = target_mesh.verts_packed().cpu().numpy()
faces = target_mesh.faces_packed().cpu().numpy()
displacement_norm = deformation_metrics.displacement_norm
area_change = deformation_metrics.vertex_metrics["area_change"]
stretch = deformation_metrics.vertex_metrics["kappa"]

fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scene"}, {"type": "scene"}]],
    subplot_titles=("|u(v)| displacement", "Area change J"),
)

fig.add_trace(
    go.Mesh3d(
        x=verts[:, 0],
        y=verts[:, 1],
        z=verts[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        intensity=displacement_norm,
        colorscale="Turbo",
        colorbar=dict(title="|u|"),
        name="Displacement",
        showscale=True,
        hoverinfo="skip",
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Mesh3d(
        x=verts[:, 0],
        y=verts[:, 1],
        z=verts[:, 2],
        i=faces[:, 0],
        j=faces[:, 1],
        k=faces[:, 2],
        intensity=area_change,
        colorscale="Viridis",
        colorbar=dict(title="J"),
        name="Area change",
        showscale=True,
        hoverinfo="skip",
    ),
    row=1,
    col=2,
)
fig.update_layout(
    title=f"Local deformation metrics for {selected_latent_path.name}",
    scene=dict(aspectmode="data"),
    scene2=dict(aspectmode="data"),
    height=600,
    width=1200,
    margin=dict(l=0, r=0, t=60, b=0),
)
fig.show()

summary = pd.DataFrame(
    {
        "metric": ["|u|", "J", "κ"],
        "mean": [
            displacement_norm.mean(),
            area_change.mean(),
            stretch.mean(),
        ],
        "std": [
            displacement_norm.std(),
            area_change.std(),
            stretch.std(),
        ],
        "max": [
            displacement_norm.max(),
            area_change.max(),
            stretch.max(),
        ],
    }
).set_index("metric")
summary


Unnamed: 0_level_0,mean,std,max
metric,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
|u|,0.078864,0.038583,0.18544
J,1.088691,0.506687,3.710156
κ,1.289787,0.258461,5.227994


In [6]:
from analysis.latent_metrics import spectral_metrics

spec = spectral_metrics(
    verts=current_target_mesh.verts_packed().cpu().numpy(),
    faces=current_target_mesh.faces_packed().cpu().numpy(),
    baseline_verts=baseline_verts,
    baseline_faces=baseline_faces,
    k=15,
)

idx = np.arange(spec.eigenvalues.shape[0])
fig_spec = go.Figure()
fig_spec.add_trace(
    go.Scatter(x=idx, y=spec.baseline_eigenvalues, name="baseline", mode="lines+markers")
)
fig_spec.add_trace(
    go.Scatter(x=idx, y=spec.eigenvalues, name=current_latent_path.name, mode="lines+markers")
)
fig_spec.update_layout(
    title=f"Laplace–Beltrami spectrum • d_spec={spec.spectrum_distance:.3f}",
    xaxis_title="Eigen index",
    yaxis_title="λ",
    height=400,
)
fig_spec.show()

hks_diff = spec.hks - spec.baseline_hks
fig_hks = go.Figure(
    data=go.Heatmap(
        z=hks_diff,
        colorscale="RdBu",
        colorbar=dict(title="ΔHKS"),
    )
)
fig_hks.update_layout(
    title="Heat Kernel Signature differences",
    xaxis_title="Scale index",
    yaxis_title="Vertex index",
    height=500,
)
fig_hks.show()

pd.DataFrame(
    {
        "mean ΔHKS": [spec.hks_distance_per_vertex.mean()],
        "max ΔHKS": [spec.hks_distance_per_vertex.max()],
    },
    index=[current_latent_path.name],
)


Unnamed: 0,mean ΔHKS,max ΔHKS
cirrhotic_115_testing_latent.pt,2.23761,10.587255


In [7]:
from analysis.latent_metrics import finite_difference_latent_sensitivity

fd_indices = list(range(8))
fd_result = finite_difference_latent_sensitivity(
    decoder,
    template,
    baseline_latent,
    latent_indices=fd_indices,
    epsilon=1e-2,
    device=DEVICE,
    return_jacobians=True,
)

fd_global_scores = pd.Series(fd_result.global_scores, index=fd_indices).sort_values(ascending=False)
fd_global_scores.name = "S_i"
fd_global_scores_df = fd_global_scores.to_frame()
fd_global_scores_df

top_latent = int(fd_global_scores.index[0])
top_col = fd_indices.index(top_latent)
top_sens = fd_result.per_vertex_sensitivity[:, top_col]

sens_fig = go.Figure(
    data=go.Mesh3d(
        x=baseline_verts[:, 0],
        y=baseline_verts[:, 1],
        z=baseline_verts[:, 2],
        i=baseline_faces[:, 0],
        j=baseline_faces[:, 1],
        k=baseline_faces[:, 2],
        intensity=top_sens,
        colorscale="Plasma",
        colorbar=dict(title=f"s_v,{top_latent}"),
        name="Sensitivity",
        showscale=True,
        hoverinfo="skip",
    )
)
sens_fig.update_layout(
    title=f"Finite-difference sensitivity for latent {top_latent}",
    scene=dict(aspectmode="data"),
    height=600,
)
sens_fig.show()



__floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').



In [8]:
from analysis.latent_metrics import autodiff_vertex_jacobian_norms

jac_norms = autodiff_vertex_jacobian_norms(
    decoder,
    template,
    baseline_latent,
    num_probes=16,
    device=DEVICE,
)

jac_fig = go.Figure(
    data=go.Mesh3d(
        x=baseline_verts[:, 0],
        y=baseline_verts[:, 1],
        z=baseline_verts[:, 2],
        i=baseline_faces[:, 0],
        j=baseline_faces[:, 1],
        k=baseline_faces[:, 2],
        intensity=jac_norms,
        colorscale="Inferno",
        colorbar=dict(title="‖J_v‖_F"),
        name="Jacobian norm",
        showscale=True,
        hoverinfo="skip",
    )
)
jac_fig.update_layout(
    title="Autodiff vertex Jacobian norms",
    scene=dict(aspectmode="data"),
    height=600,
)
jac_fig.show()

pd.DataFrame(
    {
        "mean ‖J_v‖_F": [jac_norms.mean()],
        "max ‖J_v‖_F": [jac_norms.max()],
    },
    index=["baseline"],
)


Unnamed: 0,mean ‖J_v‖_F,max ‖J_v‖_F
baseline,1.224859,4.069701


In [9]:
from analysis.latent_metrics import pullback_metric

pull_indices = list(fd_global_scores.index[:5])
pull_metric, used_indices = pullback_metric(
    decoder,
    template,
    baseline_latent,
    area_weights=None,
    device=DEVICE,
    latent_indices=pull_indices,
    epsilon=1e-2,
)

eigvals, eigvecs = np.linalg.eigh(pull_metric)
pull_df = pd.DataFrame(
    {
        "latent_index": used_indices,
        "global_score": fd_global_scores.reindex(used_indices).values,
    }
).set_index("latent_index")
pull_df["pullback_eigenvector"] = list(eigvecs.T)
print("Pullback metric eigenvalues:", eigvals[::-1])
pull_df



__floordiv__ is deprecated, and its behavior will change in a future version of pytorch. It currently rounds toward 0 (like the 'trunc' function NOT 'floor'). This results in incorrect rounding for negative values. To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor').



Pullback metric eigenvalues: [37.44395319 18.02660021 16.1441309  11.76247555  7.70904898]


Unnamed: 0_level_0,global_score,pullback_eigenvector
latent_index,Unnamed: 1_level_1,Unnamed: 2_level_1
1,4.817982,"[0.3335493263603061, -0.07204348271789623, -0...."
6,4.327742,"[-0.3736074430110545, 0.08891056373114276, -0...."
5,4.088297,"[-0.4578853190858357, -0.871066990398017, 0.10..."
4,4.064396,"[0.384619679579175, -0.3362841901537749, -0.67..."
7,3.988786,"[-0.6257563919429053, 0.3392047905711937, -0.3..."
