### Import the packages used throughout this notebook


In [None]:

import sys
sys.path.append('/home/ralbe/DALS/mesh_autodecoder')

import os
import time

import numpy as np
import pandas as pd
import torch
import trimesh

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import plotly.graph_objects as go

from pytorch3d.io import load_objs_as_meshes
from pytorch3d.ops import sample_points_from_meshes, SubdivideMeshes
from pytorch3d.loss import chamfer_distance

from model.mesh_decoder import MeshDecoder
from model.loss import mesh_bl_quality_loss
from util.metrics import point_metrics, self_intersections

from scipy.spatial import cKDTree  # Compute distance from each predicted (decoded) vertex to closest GT (target) vertex



### Load checkpoint and extract info from it 

In [None]:
model_checkpoint_path = '/home/ralbe/DALS/mesh_autodecoder/models/MeshDecoderTrainer_2025-11-06_12-00-26.ckpt'

# The checkpoint stores full Python objects, not only tensors (e.g. the
# latent-vector module whose `.weight` is read below, and the template mesh).
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle such objects, so opt out explicitly.
# SECURITY: weights_only=False runs pickle -- only load checkpoints you trust.
checkpoint = torch.load(model_checkpoint_path, map_location='cpu', weights_only=False)

# List every top-level entry except the (large) decoder weights.
for key in checkpoint:
    if key != 'decoder_state_dict':
        print(key)

hparams = checkpoint['hparams']
latent_vectors = checkpoint['latent_vectors']
best_epoch = checkpoint['best_epoch']
best_loss = checkpoint['best_loss']
train_data_path = checkpoint['train_data_path']
val_data_path = checkpoint['val_data_path']
train_file_names = checkpoint['train_filenames']
latent_features = checkpoint['latent_features']
decoder_mode = checkpoint['decoder_mode']
template = checkpoint['template']

# Print the hyperparameters, then every entry extracted above.
for key, value in hparams.items():
    print(f"{key}: {value}")
print(f"latent_vectors: {latent_vectors}")
print(f"best_epoch: {best_epoch}")
print(f"best_loss: {best_loss}")
print(f"train_data_path: {train_data_path}")
print(f"val_data_path: {val_data_path}")
print(f"latent_features: {latent_features}")
print(f"decoder_mode: {decoder_mode}")
print(f"num train_filenames: {len(train_file_names)}")


### Extract the latent-vector distribution and plot it

In [None]:
# Assuming you already have latent_vectors as a PyTorch parameter
latent_vectors_np = latent_vectors.weight.detach().cpu().numpy()
num_vectors, latent_dim = latent_vectors_np.shape

# Prepare data for 3D scatter
k_indices = np.repeat(np.arange(latent_dim), num_vectors)      # x-axis: latent dimension index
values = latent_vectors_np.T.flatten()                         # y-axis: latent value
vector_indices = np.tile(np.arange(num_vectors), latent_dim)   # z-axis: vector index

# Calculate mean and std for each latent dimension
means = latent_vectors_np.mean(axis=0)  # shape: (latent_dim,)
stds = latent_vectors_np.std(axis=0)    # shape: (latent_dim,)

# Create list of figure traces
scatter_trace = go.Scatter3d(
    x=k_indices,
    y=values,
    z=vector_indices,
    mode='markers',
    marker=dict(
        size=3,
        color=values,
        colorscale='Viridis',
        opacity=0.6,
        colorbar=dict(title='Latent Value')
    ),
    name='Values'
)

# NEW: Show mean ± std as vertical error bars instead of the shaded band

mean_trace = go.Scatter3d(
    x=np.arange(latent_dim),
    y=means,
    z=np.full(latent_dim, num_vectors // 2),
    mode='markers+lines',
    marker=dict(size=6, color='red'),
    line=dict(color='red', width=3),
    name='Mean'
)

# Add error bars for mean ± std as separate 'bar' traces at each x (dim k), at mid z
error_bar_traces = []
for k in range(latent_dim):
    error_bar_traces.append(
        go.Scatter3d(
            x=[k, k],
            y=[means[k] - stds[k], means[k] + stds[k]],
            z=[num_vectors // 2, num_vectors // 2],
            mode='lines',
            line=dict(color='orange', width=4),
            showlegend=(k==0),
            name='Mean ± Std' if k == 0 else None,
            hoverinfo='none'
        )
    )

# Create interactive 3D scatter plot with mean and std error bars
fig = go.Figure(data=[scatter_trace, mean_trace] + error_bar_traces)

fig.update_layout(
    title='Interactive 3D Scatter: Latent Vector Values per Dimension (Mean ± Std as Error Bars)',
    scene=dict(
        xaxis_title='Latent Vector Dimension k',
        yaxis_title='Latent Value',
        zaxis_title='Vector Index'
    ),
    margin=dict(l=0, r=0, b=0, t=40),
    legend=dict(x=0.01, y=0.99),
    height=1200,
    width=1200
)

fig.show()

### Decode the saved latent code z of a specific training mesh
NOTE: this result uses only the stored training-time latent code, WITHOUT test-time optimization.
NOTE 2: for reliable reconstruction metrics, more than 10k `metric_samples` are advised.
TO-DO: add test-time optimization for the training meshes as well, to obtain more accurate reconstructions.

In [None]:


# --- Configuration ---
latent_index = 10  # which stored latent code to decode (also indexes train_file_names)
output_dir = "notebook_decoded_mesh"
metric_samples = 20_000  # points sampled from each mesh for the point-based metrics

os.makedirs(output_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Instantiate decoder and template ---
hparams = checkpoint["hparams"]
decoder = MeshDecoder(
    hparams["latent_features"],
    hparams["steps"],
    hparams["hidden_features"],
    hparams["subdivide"],
    mode=hparams["decoder_mode"],
    norm=hparams["normalization"][0],
).to(device).eval()
decoder.load_state_dict(checkpoint["decoder_state_dict"])

# The stored template is subdivided once more before it is deformed.
template = checkpoint["template"].to(device)
template = SubdivideMeshes()(template)

# --- Collect the stored latent codes as a plain tensor ---
latent_module = checkpoint["latent_vectors"]
if isinstance(latent_module, torch.nn.Embedding):
    latent_tensor = latent_module.weight.detach()
elif isinstance(latent_module, torch.nn.Parameter):
    latent_tensor = latent_module.detach()
else:
    latent_tensor = torch.as_tensor(latent_module)
latent_tensor = latent_tensor.to(device)
print(f"Latent tensor shape: {latent_tensor.shape}")

# Validate the index BEFORE decoding/exporting so a bad index fails fast.
# It must address both the latent table and the list of training file names.
num_available = min(len(latent_tensor), len(train_file_names))
if not 0 <= latent_index < num_available:
    raise IndexError(
        f"latent index {latent_index} out of range for available training meshes (0-{num_available - 1})"
    )

latent_vector = latent_tensor[latent_index].unsqueeze(0)  # add a leading batch axis
print(latent_vector.shape)
print(f"Decoding latent index {latent_index}")

# --- Decode mesh ---
decode_start = time.time()
with torch.no_grad():
    # The decoder returns a sequence of meshes; the last one is the final output.
    decoded_mesh = decoder(template.clone(), latent_vector)[-1]
decode_time = time.time() - decode_start
print(f"Decode time: {decode_time:.3f}s")

# --- Save decoded mesh ---
mesh_filename = f"latent_{latent_index:03d}.obj"
mesh_path = os.path.join(output_dir, mesh_filename)
decoded_trimesh = trimesh.Trimesh(
    vertices=decoded_mesh.verts_packed().cpu().numpy(),
    faces=decoded_mesh.faces_packed().cpu().numpy(),
    process=False,
)
decoded_trimesh.export(mesh_path)
print(f"Saved decoded mesh to {mesh_path}")

# --- Load target mesh for metrics ---
target_mesh_path = os.path.join(train_data_path, train_file_names[latent_index])
print(f"Using target mesh: {target_mesh_path}")

target_mesh = load_objs_as_meshes([target_mesh_path], device=device)

# --- Metric computation ---
with torch.no_grad():
    pred_samples = sample_points_from_meshes(decoded_mesh, metric_samples)
    true_samples = sample_points_from_meshes(target_mesh, metric_samples)
    chamfer_val = chamfer_distance(true_samples, pred_samples)[0] * 10000
    metric_dict = point_metrics(true_samples, pred_samples, [0.01, 0.02])
    bl_quality = (1.0 - mesh_bl_quality_loss(decoded_mesh)).item()

# Self-intersections are reported as a percentage of the face count.
decoded_mesh_cpu = decoded_mesh.cpu()
ints_tensor, _ = self_intersections(decoded_mesh_cpu)
faces_count = len(decoded_mesh_cpu.faces_packed())
ints_percent = 100.0 * float(ints_tensor[0]) / max(faces_count, 1)

# Single-sample summary: every *_std column is 0. No latent search runs in this
# cell, so both timing columns report the decode time.
summary = {
    "ChamferL2 x 10000_mean": chamfer_val.item(),
    "ChamferL2 x 10000_std": 0.0,
    "BL quality_mean": bl_quality,
    "BL quality_std": 0.0,
    "No. ints._mean": ints_percent,
    "No. ints._std": 0.0,
    "Precision@0.01_mean": metric_dict["Precision@0.01"].item(),
    "Precision@0.01_std": 0.0,
    "Recall@0.01_mean": metric_dict["Recall@0.01"].item(),
    "Recall@0.01_std": 0.0,
    "F1@0.01_mean": metric_dict["F1@0.01"].item(),
    "F1@0.01_std": 0.0,
    "Precision@0.02_mean": metric_dict["Precision@0.02"].item(),
    "Precision@0.02_std": 0.0,
    "Recall@0.02_mean": metric_dict["Recall@0.02"].item(),
    "Recall@0.02_std": 0.0,
    "F1@0.02_mean": metric_dict["F1@0.02"].item(),
    "F1@0.02_std": 0.0,
    "Search_mean": decode_time,
    "Search_std": 0.0,
    "Total_mean": decode_time,
    "Total_std": 0.0,
    "num_test_samples": 1,
}

metrics_df = pd.DataFrame([summary])
metrics_path = os.path.join(output_dir, f"latent_{latent_index:03d}_metrics.csv")
metrics_df.to_csv(metrics_path, index=False)
print(f"Metrics saved to {metrics_path}")
display(metrics_df.T.rename(columns={0: "value"}))


### Visualize the decoded training mesh against its target mesh

In [None]:
visualize = True
if visualize:

    def mesh_wire_edges(vertices, faces):
        """Return the unique undirected edges of a triangle mesh as an (E, 2) int array.

        `vertices` is not used; it is kept so existing calls keep working.
        """
        edges = np.concatenate([
            faces[:, [0, 1]],
            faces[:, [1, 2]],
            faces[:, [2, 0]]
        ], axis=0)
        # Sort each pair so (a, b) and (b, a) collapse into one edge.
        return np.unique(np.sort(edges, axis=1), axis=0)

    def _edge_polyline(vertices, edges):
        """Return x, y, z arrays that draw every edge as its own segment.

        Each edge contributes its two endpoints followed by a NaN point; plotly
        treats NaN as a gap, so the whole wireframe fits in ONE trace instead of
        one trace per edge.
        """
        pts = np.full((len(edges), 3, 3), np.nan)
        pts[:, :2, :] = vertices[edges]  # (E, 2, 3) endpoint coordinates
        pts = pts.reshape(-1, 3)
        return pts[:, 0], pts[:, 1], pts[:, 2]

    decoded_vertices = decoded_mesh.verts_packed().cpu().numpy()
    decoded_faces = decoded_mesh.faces_packed().cpu().numpy()
    decoded_edges = mesh_wire_edges(decoded_vertices, decoded_faces)

    target_vertices = target_mesh.verts_packed().cpu().numpy()
    target_faces = target_mesh.faces_packed().cpu().numpy()
    target_edges = mesh_wire_edges(target_vertices, target_faces)

    # --- Figure 1: decoded and target wireframes overlaid (one trace each, legend-toggleable) ---
    fig = go.Figure()

    decoded_x, decoded_y, decoded_z = _edge_polyline(decoded_vertices, decoded_edges)
    fig.add_trace(go.Scatter3d(
        x=decoded_x, y=decoded_y, z=decoded_z,
        mode="lines",
        line=dict(color="royalblue", width=2),
        name="Decoded",
        showlegend=True,
        legendgroup="Decoded"
    ))

    target_x, target_y, target_z = _edge_polyline(target_vertices, target_edges)
    fig.add_trace(go.Scatter3d(
        x=target_x, y=target_y, z=target_z,
        mode="lines",
        line=dict(color="orange", width=2),
        name="Target",
        showlegend=True,
        legendgroup="Target"
    ))

    fig.update_layout(
        title=f"Decoded vs Target Mesh Wireframes (latent {latent_index})",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        height=900,  # Taller figure for better vertical viewing
        width=1300,
        legend=dict(
            x=0.01, y=0.99,
            itemsizing="constant",
            traceorder="normal",
            itemclick="toggle",             # Single click to toggle on/off
            itemdoubleclick="toggleothers"
        )
    )
    fig.show()

    # --- Figure 2: decoded wireframe coloured by distance to the closest target vertex ---
    tree_gt = cKDTree(target_vertices)
    dists, _ = tree_gt.query(decoded_vertices, k=1)  # (num_decoded_verts,)

    # One colour value per polyline point. Plotly interpolates the colour along
    # each segment, so every edge shades from one endpoint's distance to the
    # other's. The gap (NaN) points reuse the second endpoint's value; they are
    # never drawn.
    wire_x, wire_y, wire_z = _edge_polyline(decoded_vertices, decoded_edges)
    point_dists = dists[decoded_edges][:, [0, 1, 1]].ravel()

    fig = go.Figure(go.Scatter3d(
        x=wire_x, y=wire_y, z=wire_z,
        mode="lines",
        line=dict(
            color=point_dists,
            colorscale="Viridis",
            cmin=float(dists.min()),
            cmax=float(dists.max()),
            width=2,
            showscale=True,
            colorbar=dict(title='Dist. to Closest<br>Target Vertex'),
        ),
        showlegend=False,
        hoverinfo="skip"
    ))

    fig.update_layout(
        title=f"Predicted Mesh Wireframe, colored by dist. to closest GT vertex (latent {latent_index})",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        height=900,  # Taller figure for better vertical viewing
        width=1300
    )
    fig.show()

### Same analysis as above, but for a test mesh decoded from its test-time-optimized latent code. Test-time optimization clearly improves the reconstruction metrics, which reach roughly 99% here.

In [None]:


# --- Configuration ---
patient_ID = 72  # test patient whose optimized latent code is decoded
disease_status = 'healthy'  # Choose between "healthy" and "cirrhotic"
output_dir = "notebook_decoded_mesh"
metric_samples = 12_500  # points sampled from each mesh for the point-based metrics

if disease_status not in ("healthy", "cirrhotic"):
    raise ValueError(f'disease_status must be "healthy" or "cirrhotic", got {disease_status!r}')

os.makedirs(output_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Instantiate decoder and template ---
hparams = checkpoint["hparams"]
decoder = MeshDecoder(
    hparams["latent_features"],
    hparams["steps"],
    hparams["hidden_features"],
    hparams["subdivide"],
    mode=hparams["decoder_mode"],
    norm=hparams["normalization"][0],
).to(device).eval()
decoder.load_state_dict(checkpoint["decoder_state_dict"])

template = checkpoint["template"].to(device)
# Subdivide again for a 4-times subdivided icosahedron template
template = SubdivideMeshes()(template)

# --- Load the test-time-optimized latent code and the reference meshes ---
base_dir = '/home/ralbe/DALS/mesh_autodecoder/inference_results/meshes_MeshDecoderTrainer_2025-11-06_12-00-26/'
case_name = f'{disease_status}_{patient_ID}'
latent_vector_file = os.path.join(base_dir, 'latents', f'{case_name}_testing_latent.pt')
target_mesh_file = os.path.join(base_dir, f'{case_name}_testing_target.obj')
decoded_mesh_file = os.path.join(base_dir, f'{case_name}_testing_optimized.obj')

# Assumes the .pt file holds a single latent tensor -- TODO confirm.
latent_vector = torch.load(latent_vector_file, map_location=device)
# The decoder is called with a leading batch axis (the training cell passes
# latent_tensor[i].unsqueeze(0)); add it if the file stores a bare vector.
if latent_vector.dim() == 1:
    latent_vector = latent_vector.unsqueeze(0)

target_mesh = load_objs_as_meshes([target_mesh_file], device=device)
# Mesh saved by the inference run. The metrics below use the freshly decoded
# mesh; this one is kept only for manual comparison.
decoded_mesh_0 = load_objs_as_meshes([decoded_mesh_file], device=device)

# --- Decode mesh ---
decode_start = time.time()
with torch.no_grad():
    # The decoder returns a sequence of meshes; the last one is the final output.
    decoded_mesh = decoder(template.clone(), latent_vector)[-1]
decode_time = time.time() - decode_start
print(f"Decode time: {decode_time:.3f}s")

# --- Save decoded mesh ---
# The filename is keyed on the test case, so it never overwrites the
# training-latent output written by the previous cell.
mesh_filename = f"{case_name}_decoded.obj"
mesh_path = os.path.join(output_dir, mesh_filename)
decoded_trimesh = trimesh.Trimesh(
    vertices=decoded_mesh.verts_packed().cpu().numpy(),
    faces=decoded_mesh.faces_packed().cpu().numpy(),
    process=False,
)
decoded_trimesh.export(mesh_path)
print(f"Saved decoded mesh to {mesh_path}")

# --- Metric computation ---
with torch.no_grad():
    pred_samples = sample_points_from_meshes(decoded_mesh, metric_samples)
    true_samples = sample_points_from_meshes(target_mesh, metric_samples)
    chamfer_val = chamfer_distance(true_samples, pred_samples)[0] * 10000
    metric_dict = point_metrics(true_samples, pred_samples, [0.01, 0.02])
    bl_quality = (1.0 - mesh_bl_quality_loss(decoded_mesh)).item()

# Self-intersections are reported as a percentage of the face count.
decoded_mesh_cpu = decoded_mesh.cpu()
ints_tensor, _ = self_intersections(decoded_mesh_cpu)
faces_count = len(decoded_mesh_cpu.faces_packed())
ints_percent = 100.0 * float(ints_tensor[0]) / max(faces_count, 1)

# Single-sample summary: every *_std column is 0. Only the re-decode is timed
# here (the test-time latent search happened in the inference run).
summary = {
    "ChamferL2 x 10000_mean": chamfer_val.item(),
    "ChamferL2 x 10000_std": 0.0,
    "BL quality_mean": bl_quality,
    "BL quality_std": 0.0,
    "No. ints._mean": ints_percent,
    "No. ints._std": 0.0,
    "Precision@0.01_mean": metric_dict["Precision@0.01"].item(),
    "Precision@0.01_std": 0.0,
    "Recall@0.01_mean": metric_dict["Recall@0.01"].item(),
    "Recall@0.01_std": 0.0,
    "F1@0.01_mean": metric_dict["F1@0.01"].item(),
    "F1@0.01_std": 0.0,
    "Precision@0.02_mean": metric_dict["Precision@0.02"].item(),
    "Precision@0.02_std": 0.0,
    "Recall@0.02_mean": metric_dict["Recall@0.02"].item(),
    "Recall@0.02_std": 0.0,
    "F1@0.02_mean": metric_dict["F1@0.02"].item(),
    "F1@0.02_std": 0.0,
    "Search_mean": decode_time,
    "Search_std": 0.0,
    "Total_mean": decode_time,
    "Total_std": 0.0,
    "num_test_samples": 1,
}

metrics_df = pd.DataFrame([summary])
metrics_path = os.path.join(output_dir, f"{case_name}_metrics.csv")
metrics_df.to_csv(metrics_path, index=False)
print(f"Metrics saved to {metrics_path}")
display(metrics_df.T.rename(columns={0: "value"}))


In [None]:
visualize = False

if visualize:

    # --- Visualize decoded vs target mesh as wireframes ---

    def mesh_wire_edges(vertices, faces):
        """Return the unique undirected edges of a triangle mesh as an (E, 2) int array.

        `vertices` is not used; it is kept so existing calls keep working.
        """
        edges = np.concatenate([
            faces[:, [0, 1]],
            faces[:, [1, 2]],
            faces[:, [2, 0]]
        ], axis=0)
        # Sort each pair so (a, b) and (b, a) collapse into one edge.
        return np.unique(np.sort(edges, axis=1), axis=0)

    def _edge_polyline(vertices, edges):
        """Return x, y, z arrays that draw every edge as its own segment.

        Each edge contributes its two endpoints followed by a NaN point; plotly
        treats NaN as a gap, so the whole wireframe fits in ONE trace instead of
        one trace per edge.
        """
        pts = np.full((len(edges), 3, 3), np.nan)
        pts[:, :2, :] = vertices[edges]  # (E, 2, 3) endpoint coordinates
        pts = pts.reshape(-1, 3)
        return pts[:, 0], pts[:, 1], pts[:, 2]

    case_label = f"{disease_status} patient {patient_ID}"

    decoded_vertices = decoded_mesh.verts_packed().cpu().numpy()
    decoded_faces = decoded_mesh.faces_packed().cpu().numpy()
    decoded_edges = mesh_wire_edges(decoded_vertices, decoded_faces)

    target_vertices = target_mesh.verts_packed().cpu().numpy()
    target_faces = target_mesh.faces_packed().cpu().numpy()
    target_edges = mesh_wire_edges(target_vertices, target_faces)

    # --- Figure 1: decoded and target wireframes overlaid (one trace each, legend-toggleable) ---
    fig = go.Figure()

    decoded_x, decoded_y, decoded_z = _edge_polyline(decoded_vertices, decoded_edges)
    fig.add_trace(go.Scatter3d(
        x=decoded_x, y=decoded_y, z=decoded_z,
        mode="lines",
        line=dict(color="royalblue", width=2),
        name="Decoded",
        showlegend=True,
        legendgroup="Decoded"
    ))

    target_x, target_y, target_z = _edge_polyline(target_vertices, target_edges)
    fig.add_trace(go.Scatter3d(
        x=target_x, y=target_y, z=target_z,
        mode="lines",
        line=dict(color="orange", width=2),
        name="Target",
        showlegend=True,
        legendgroup="Target"
    ))

    fig.update_layout(
        title=f"Decoded vs Target Mesh Wireframes ({case_label})",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        height=900,  # Taller figure for better vertical viewing
        width=1300,
        legend=dict(
            x=0.01, y=0.99,
            itemsizing="constant",
            traceorder="normal",
            itemclick="toggle",             # Single click to toggle on/off
            itemdoubleclick="toggleothers"
        )
    )
    fig.show()

    # --- Figure 2: decoded wireframe coloured by distance to the closest target vertex ---
    tree_gt = cKDTree(target_vertices)
    dists, _ = tree_gt.query(decoded_vertices, k=1)  # (num_decoded_verts,)

    # One colour value per polyline point. Plotly interpolates the colour along
    # each segment, so every edge shades from one endpoint's distance to the
    # other's. The gap (NaN) points reuse the second endpoint's value; they are
    # never drawn.
    wire_x, wire_y, wire_z = _edge_polyline(decoded_vertices, decoded_edges)
    point_dists = dists[decoded_edges][:, [0, 1, 1]].ravel()

    fig = go.Figure(go.Scatter3d(
        x=wire_x, y=wire_y, z=wire_z,
        mode="lines",
        line=dict(
            color=point_dists,
            colorscale="Viridis",
            cmin=float(dists.min()),
            cmax=float(dists.max()),
            width=2,
            showscale=True,
            colorbar=dict(title='Dist. to Closest<br>Target Vertex'),
        ),
        showlegend=False,
        hoverinfo="skip"
    ))

    fig.update_layout(
        title=f"Predicted Mesh Wireframe, colored by dist. to closest GT vertex ({case_label})",
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        height=900,  # Taller figure for better vertical viewing
        width=1300
    )
    fig.show()

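### Decode the average latent vector, and a manually modified copy of it, to see how the latent dimensions affect the shape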

In [None]:
print(latent_vectors_np.shape)  # e.g. (345, 128)

# NOTE: latent tensors are created directly on `device`, the device the decoder
# and template were moved to (cuda when available). Building them on the CPU
# would make the decoder call fail with a device mismatch.

# --- Case A: Deform the template using the average latent vector ---
avg_latent_vector = latent_vectors_np.mean(axis=0)
print("avg latent shape", avg_latent_vector.shape)  # e.g. (128,)
avg_latent_tensor = torch.tensor(avg_latent_vector, dtype=torch.float32, device=device).unsqueeze(0)
with torch.no_grad():  # inference only: no autograd graph needed
    decoded_avg_mesh = decoder(template.clone(), avg_latent_tensor)[-1]
decoded_avg_mesh_cpu = decoded_avg_mesh.cpu()
avg_vertices = decoded_avg_mesh_cpu.verts_packed().detach().numpy()
avg_faces = decoded_avg_mesh_cpu.faces_packed().detach().numpy()

# --- Case B: Overwrite a slice of the avg latent vector with a constant, then decode ---
lat_mod_slice = slice(20, 40)
lat_mod_value = 0.4
lat_mod_label = f"[{lat_mod_slice.start}:{lat_mod_slice.stop}]={lat_mod_value}"
lat_mod = avg_latent_vector.copy()
lat_mod[lat_mod_slice] = lat_mod_value
lat_mod_tensor = torch.tensor(lat_mod, dtype=torch.float32, device=device).unsqueeze(0)
with torch.no_grad():
    decoded_lat_mod_mesh = decoder(template.clone(), lat_mod_tensor)[-1]
decoded_lat_mod_mesh_cpu = decoded_lat_mod_mesh.cpu()
mod_vertices = decoded_lat_mod_mesh_cpu.verts_packed().detach().numpy()
mod_faces = decoded_lat_mod_mesh_cpu.faces_packed().detach().numpy()


def plot_wireframe(vertices, faces, title_str="Wireframe Mesh", color="blue"):
    """Show a triangle mesh as an interactive plotly wireframe.

    Every undirected edge is drawn once, even when two faces share it, and all
    edges go into a single line trace. Segments are separated by NaN points,
    which plotly treats as gaps, instead of creating one trace per edge.

    Args:
        vertices: (V, 3) array of vertex coordinates.
        faces: (F, 3) integer array of triangle vertex indices.
        title_str: figure title.
        color: line colour for the whole wireframe.
    """
    vertices = np.asarray(vertices)
    faces = np.asarray(faces)
    edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
    edges = np.unique(np.sort(edges, axis=1), axis=0)

    pts = np.full((len(edges), 3, 3), np.nan)
    pts[:, :2, :] = vertices[edges]  # two endpoints per edge, then a NaN gap
    pts = pts.reshape(-1, 3)

    fig = go.Figure(go.Scatter3d(
        x=pts[:, 0],
        y=pts[:, 1],
        z=pts[:, 2],
        mode='lines',
        line=dict(color=color, width=2),
        hoverinfo='skip',
        showlegend=False
    ))
    fig.update_layout(
        scene=dict(
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z"
        ),
        title=title_str,
        height=600,
        width=900,
    )
    fig.show()


# Plot both meshes
plot_wireframe(avg_vertices, avg_faces, title_str="Wireframe: avg latent vector", color="blue")
plot_wireframe(mod_vertices, mod_faces, title_str=f"Wireframe: avg latent with {lat_mod_label}", color="red")




In [None]:
# Wrap both decoded meshes as trimesh objects. process=False turns off trimesh's
# automatic cleanup (such as merging duplicate vertices), so the geometry is
# kept exactly as decoded.
mesh_avg, mesh_mod = (
    trimesh.Trimesh(vertices=verts, faces=tris, process=False)
    for verts, tris in ((avg_vertices, avg_faces), (mod_vertices, mod_faces))
)

# Interactive view of the mesh decoded from the average latent vector.
mesh_avg.show()


In [None]:

# Interactive view of the mesh decoded from the modified average latent vector
# (compare with mesh_avg).
mesh_mod.show()