# Latent Mesh Explorer

Interactively decode saved `MeshDecoder` latent vectors, override a single latent coefficient, and compare the resulting mesh wireframe against the unmodified reconstruction.


In [10]:
import os
import sys
from pathlib import Path
from typing import Dict, Tuple

import numpy as np
import torch
from torch.nn import Embedding
# from torch.serialization import add_safe_globals

import plotly.graph_objects as go
from plotly.subplots import make_subplots

import ipywidgets as widgets
from IPython.display import display
from pytorch3d.structures import Meshes
# import subdivide meshes
from pytorch3d.ops import subdivide_meshes

# Root of the mesh_autodecoder project. Override with the MESH_AUTODECODER_ROOT
# environment variable when running somewhere other than the original machine.
PROJECT_ROOT = Path(
    os.environ.get('MESH_AUTODECODER_ROOT', '/home/ralbe/DALS/mesh_autodecoder')
).resolve()
# Prepend (rather than append) so the project's `model` package takes precedence
# over any other top-level module named `model` that is already importable.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from model.mesh_decoder import MeshDecoder

# Training run whose checkpoint and inferred latents are explored below. Both
# paths derive from this single name so the checkpoint and latents cannot
# silently come from different runs.
RUN_NAME = 'MeshDecoderTrainer_2025-11-06_12-00-26'
CHECKPOINT_PATH = PROJECT_ROOT / 'models' / f'{RUN_NAME}.ckpt'
LATENT_DIR = PROJECT_ROOT / 'inference_results' / f'meshes_{RUN_NAME}' / 'latents'

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# If torch.load refuses the checkpoint under weights-only loading, allow-list the
# pickled classes first, e.g. torch.serialization.add_safe_globals([Embedding]).


In [11]:
# Load the training checkpoint and rebuild the decoder from its stored hyperparameters.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
hparams: Dict = checkpoint['hparams']

decoder_kwargs = dict(
    latent_features=hparams['latent_features'],
    steps=hparams['steps'],
    hidden_features=hparams['hidden_features'],
    subdivide=hparams['subdivide'],
    mode=hparams['decoder_mode'],
    # Only the first entry of the stored 'normalization' hparam is passed on.
    norm=hparams['normalization'][0],
)
decoder = MeshDecoder(**decoder_kwargs)
decoder.load_state_dict(checkpoint['decoder_state_dict'])
# Module.to / Module.eval return the module itself, so this is the same object.
decoder = decoder.to(DEVICE).eval()

latent_dim = hparams['latent_features']

# Template mesh fed to the decoder, subdivided once here.
# NOTE(review): the decoder also receives a `subdivide` hparam; confirm this extra
# subdivision matches how the template was prepared during training.
subdiv = subdivide_meshes.SubdivideMeshes()
template: Meshes = subdiv(checkpoint['template'].to(DEVICE)).to(DEVICE)

print(f'Device: {DEVICE}')
print(f'Latent dimension: {latent_dim}')


Device: cpu
Latent dimension: 128


In [12]:
# Per-path cache of loaded latent vectors. Callers only ever receive clones of
# these entries, so edits made by a caller never leak back into the cache.
latent_cache: Dict[Path, torch.Tensor] = {}

def load_latent_tensor(path: Path) -> torch.Tensor:
    """Load a single 1-D latent vector from ``path`` and cache it per path.

    The file may contain either a tensor(-like) object directly, or a dict that
    holds it under the first present key among 'latent', 'latent_vector' and
    'latent_vectors'. Leading singleton dimensions (a saved batch of one, e.g.
    shape (1, D) or (1, 1, D)) are removed. The result is cast to float32.

    Args:
        path: File readable by ``torch.load``. It is loaded onto the CPU.

    Returns:
        A fresh clone of the cached 1-D tensor, safe for the caller to modify.

    Raises:
        KeyError: The file holds a dict without any of the known latent keys.
        ValueError: The loaded tensor is not a single vector after squeezing.
    """
    if path not in latent_cache:
        loaded = torch.load(path, map_location='cpu')
        if isinstance(loaded, dict):
            for key in ('latent', 'latent_vector', 'latent_vectors'):
                if key in loaded:
                    loaded = loaded[key]
                    break
            else:
                raise KeyError(f'No latent vector found in {path}')
        tensor = torch.as_tensor(loaded).detach().float()
        # Flatten a batch of one; anything else (N > 1 rows, 0-d, (D, 1)) would
        # otherwise fail much later inside render() or the decoder.
        if tensor.ndim > 1 and all(size == 1 for size in tensor.shape[:-1]):
            tensor = tensor.reshape(-1)
        if tensor.ndim != 1:
            raise ValueError(
                f'Expected a single latent vector in {path}, got shape {tuple(tensor.shape)}'
            )
        latent_cache[path] = tensor.clone()
    return latent_cache[path].clone()

def decode_latent(latent_1d: torch.Tensor) -> Meshes:
    """Decode one 1-D latent vector and return the decoder's last output on the CPU.

    The vector gains a leading batch dimension and is moved to DEVICE. The
    decoder receives a fresh clone of the module-level template, and gradient
    tracking is disabled. The decoder output is indexed with [-1]; presumably it
    is a sequence of meshes, one per decoding step (TODO confirm in MeshDecoder).
    """
    batched_latent = latent_1d.to(DEVICE).unsqueeze(0)
    with torch.no_grad():
        step_outputs = decoder(template.clone(), batched_latent)
    final_mesh = step_outputs[-1]
    return final_mesh.cpu()

def mesh_bounds(mesh: Meshes) -> Tuple[np.ndarray, np.ndarray]:
    """Return the axis-aligned bounding box of all vertices in ``mesh``.

    Args:
        mesh: Mesh (or batch of meshes) whose packed vertices are inspected.

    Returns:
        ``(mins, maxs)``: two NumPy arrays holding the per-coordinate minimum
        and maximum over every vertex.
    """
    # detach() so .numpy() also works if the vertices carry autograd history.
    verts = mesh.verts_packed().detach()
    mins = verts.min(dim=0).values
    maxs = verts.max(dim=0).values
    return mins.cpu().numpy(), maxs.cpu().numpy()

def mesh_to_wireframe_trace(mesh: Meshes, color: str, name: str) -> go.Scatter3d:
    """Turn a mesh's edges into a single plotly line trace.

    Each edge becomes a (start, end, NaN) triple of points. The NaN row breaks
    the polyline, so plotly draws disjoint segments within one trace.

    Args:
        mesh: Mesh to draw. It is copied to the CPU first.
        color: Line colour for the whole wireframe.
        name: Trace name.

    Returns:
        A ``go.Scatter3d`` in 'lines' mode with width 2 and hover disabled.
    """
    host_mesh = mesh.cpu()
    vertices = host_mesh.verts_packed().numpy()
    edge_index = host_mesh.edges_packed().numpy()

    # (edges, 3 points per edge, xyz): start vertex, end vertex, NaN separator.
    segments = np.empty((edge_index.shape[0], 3, 3), dtype=np.float32)
    segments[:, 0] = vertices[edge_index[:, 0]]
    segments[:, 1] = vertices[edge_index[:, 1]]
    segments[:, 2] = np.nan
    xs, ys, zs = segments.reshape(-1, 3).T

    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode='lines',
        name=name,
        line=dict(color=color, width=2),
        hoverinfo='skip',
    )

def make_comparison_figure(
    original: Meshes,
    modified: Meshes,
    latent_index: int,
    new_value: float,
    original_value: float,
) -> go.Figure:
    """Build a two-panel 3-D wireframe figure comparing two decoded meshes.

    Left panel: the original mesh alone, framed to its own bounding box.
    Right panel: the original mesh (thin, light blue) under the modified mesh
    (red), framed to the union of both bounding boxes so neither is clipped.

    Args:
        original: Mesh decoded from the unmodified latent vector.
        modified: Mesh decoded after overriding one latent coefficient.
        latent_index: Index of the overridden coefficient (shown in titles).
        new_value: Value written at ``latent_index`` (shown in titles).
        original_value: Value originally at ``latent_index`` (shown in titles).

    Returns:
        A plotly figure with two 3-D scenes, hidden axes and no legend.
    """
    o_min, o_max = mesh_bounds(original)
    m_min, m_max = mesh_bounds(modified)

    def axis_ranges(mins: np.ndarray, maxs: np.ndarray):
        # Cube centred on the box. Every axis gets the same half-width (half the
        # largest extent, floored at 1e-3 for flat/degenerate meshes), so
        # aspectmode='cube' keeps the mesh's true proportions.
        center = 0.5 * (mins + maxs)
        half = max(np.max(maxs - mins) / 2.0, 1e-3)
        return [[center[i] - half, center[i] + half] for i in range(3)]

    def scene_layout(ranges) -> dict:
        # Hidden axes with fixed ranges, used for both panels.
        return dict(
            xaxis=dict(visible=False, range=ranges[0]),
            yaxis=dict(visible=False, range=ranges[1]),
            zaxis=dict(visible=False, range=ranges[2]),
            aspectmode='cube',
        )

    original_ranges = axis_ranges(o_min, o_max)
    combined_ranges = axis_ranges(np.minimum(o_min, m_min), np.maximum(o_max, m_max))

    fig = make_subplots(
        rows=1,
        cols=2,
        specs=[[{'type': 'scene'}, {'type': 'scene'}]],
        subplot_titles=(
            f'Original (index {latent_index} = {original_value:+.4f})',
            f'Overlay (index {latent_index} → {new_value:+.4f})',
        ),
    )

    fig.add_trace(mesh_to_wireframe_trace(original, '#636EFA', 'Original'), row=1, col=1)

    baseline_trace = mesh_to_wireframe_trace(original, '#9CA3FF', 'Original baseline')
    baseline_trace.line['width'] = 1
    fig.add_trace(baseline_trace, row=1, col=2)
    fig.add_trace(mesh_to_wireframe_trace(modified, '#EF553B', 'Modified'), row=1, col=2)

    fig.update_layout(
        height=600,
        width=1100,
        margin=dict(l=0, r=0, t=60, b=0),
        # Left panel shows only the original mesh, so its own bounds suffice.
        scene=scene_layout(original_ranges),
        # The overlay also contains the modified mesh, which may extend beyond
        # the original; use the combined bounds so nothing is clipped.
        scene2=scene_layout(combined_ranges),
        showlegend=False,
    )
    return fig


In [13]:
# Every `*_latent.pt` file in LATENT_DIR, in sorted order.
latent_files = sorted(LATENT_DIR.glob('*_latent.pt'))
if len(latent_files) == 0:
    raise FileNotFoundError(f'No latent tensors found in {LATENT_DIR}')

# The dropdown shows the bare file name; its value is the full path as a string.
latent_choices = [(latent_file.name, str(latent_file)) for latent_file in latent_files]
latent_dropdown = widgets.Dropdown(
    options=latent_choices,
    value=latent_choices[0][1],
    description='Latent file:',
    layout=widgets.Layout(width='60%'),
)

# Which latent coefficient to override. The upper bound is re-synced whenever a
# different file is selected (see handle_latent_change).
index_text = widgets.BoundedIntText(
    min=0,
    max=latent_dim - 1,
    value=0,
    step=1,
    description='Index:',
    layout=widgets.Layout(width='120px'),
    style={'description_width': 'initial'},
)

# Replacement value for the chosen coefficient. It only fires on release
# (continuous_update=False) because every change decodes two meshes.
value_slider = widgets.FloatSlider(
    min=-0.4, max=0.4, step=0.01, value=0.0,
    description='New value:',
    continuous_update=False,
)

# Displays the selected coefficient's original (pre-override) value.
info_label = widgets.HTML()


def render(latent_path: str, latent_index: int, slider_value: float):
    """Decode a latent with and without one coefficient overridden, then display both.

    Also updates ``info_label`` with the overridden coefficient's original value.

    Args:
        latent_path: Path (as a string) of the latent file to load.
        latent_index: Coefficient to override. It is clamped into the tensor's range.
        slider_value: Value written at ``latent_index`` in the modified copy.
    """
    base_latent = load_latent_tensor(Path(latent_path))
    # Clamp in case the index widget briefly holds a value outside this tensor.
    last_index = base_latent.shape[0] - 1
    idx = int(min(max(latent_index, 0), last_index))

    original_value = float(base_latent[idx])
    info_label.value = f'<b>Original value:</b> {original_value:+.4f}'

    edited_latent = base_latent.clone()
    edited_latent[idx] = slider_value

    display(
        make_comparison_figure(
            decode_latent(base_latent),
            decode_latent(edited_latent),
            latent_index=idx,
            new_value=slider_value,
            original_value=original_value,
        )
    )


def handle_latent_change(change):
    """Re-sync the index and slider widgets after a different latent file is chosen.

    Sets the index bound to the new tensor's length, resets the index to 0, and
    moves the slider to coefficient 0's value (clipped to the slider's range).
    """
    selected = load_latent_tensor(Path(change['new']))
    # The max must be updated before the value so the bounded widget accepts it.
    index_text.max = selected.shape[0] - 1
    index_text.value = 0

    first_value = float(selected[0])
    slider_target = float(np.clip(first_value, value_slider.min, value_slider.max))
    if abs(value_slider.value - slider_target) > 1e-6:
        value_slider.value = slider_target
    else:
        # The slider is unchanged, so it triggers no re-render; update the label directly.
        info_label.value = f'<b>Original value:</b> {first_value:+.4f}'


def handle_index_change(change):
    """Show the newly selected coefficient's original value and move the slider to it.

    The slider target is the original value clipped to the slider's range. The
    slider is only written when that target differs from its current value.
    """
    current = load_latent_tensor(Path(latent_dropdown.value))
    last_index = current.shape[0] - 1
    requested = change['new']
    # Clamp into [0, last_index].
    idx = 0 if requested < 0 else min(requested, last_index)

    original_value = float(current[idx])
    info_label.value = f'<b>Original value:</b> {original_value:+.4f}'

    slider_target = float(np.clip(original_value, value_slider.min, value_slider.max))
    if abs(value_slider.value - slider_target) > 1e-6:
        value_slider.value = slider_target

# Sync handlers are registered before interactive_output, so they run first on
# each value change.
for source_widget, handler in (
    (latent_dropdown, handle_latent_change),
    (index_text, handle_index_change),
):
    source_widget.observe(handler, names='value')

# Re-run render() whenever any of these widgets changes value.
plot_controls = {
    'latent_path': latent_dropdown,
    'latent_index': index_text,
    'slider_value': value_slider,
}
interactive_plot = widgets.interactive_output(render, plot_controls)

controls = widgets.VBox(
    [latent_dropdown, widgets.HBox([index_text, value_slider]), info_label]
)

# Initial sync of the index bound, slider and label for the first file.
handle_latent_change({'new': latent_dropdown.value})
display(controls, interactive_plot)


VBox(children=(Dropdown(description='Latent file:', layout=Layout(width='60%'), options=(('cirrhotic_115_testi…

Output()