<a href="https://colab.research.google.com/github/aubricot/nsm/blob/main/demos/shape_completion_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Squamate Vertebra Shape Completion Demo**
*Last edited 19 Jan 2026*


This notebook demonstrates shape completion for partial vertebrae (modern and fossil) using a trained Neural Shape Model (NSM; Gatti et al. 2025, Park et al. 2019). It can be run fully in demo mode without connecting to your Google Drive. Adjust parameters using form fields and make sure your runtime environment is set to run on GPU. Full repository code is available at [aubricot/nsm on GitHub](https://github.com/aubricot/nsm).

Modern vertebra meshes are derived from micro-CT data produced by the oVert Initiative (Blackburn et al. 2024). Fossil vertebrae were downloaded from MorphoSource ([UF546657](https://doi.org/10.17602/M2/M600663); [UF271967](https://n2t.net/ark:/87602/m4/M69199)). All vertebrae were aligned and scaled using ATLAS before training (Porto et al. 2026).


**References**
* Blackburn et al. 2024, BioScience. https://doi.org/10.1093/biosci/biad120
* Gatti et al. 2025, IEEE TMI. https://doi.org/10.1109/tmi.2024.3485613
* Park et al. 2019, CVPR. https://doi.org/10.48550/arXiv.1901.05103
* Porto et al. 2026, in prep. https://github.com/agporto/ATLAS

## 1. Installs & Imports

In [None]:
#@title Check GPU and CUDA info - make sure Colab Runtime set to GPU
from psutil import virtual_memory

# Check GPU and CUDA
!nvcc --version
gpu = !nvidia-smi
gpu = '\n'.join(gpu)
print('\033[91mNot connected to a GPU\033[0m' if 'failed' in gpu else gpu)

# Check RAM
ram = virtual_memory().total / 1e9
print(f'\033[92mYour runtime has {ram:.1f} GB of RAM\033[0m\n')

In [None]:
#@title Choose where to save results

# Use dropdown menu on right
save = "in Colab runtime (files deleted after each session)" #@param ["in my Google Drive", "in Colab runtime (files deleted after each session)"]

# Mount Google Drive only when the user chose to keep results there
if 'Google Drive' in save:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)

# Type in the parent directory for the project in the form field on right.
# "/nsm" is appended below, so `wd` is where the nsm repository lives.
# NOTE(review): the dropdown suggestion already ends in "/nsm", which would
# yield ".../nsm/nsm" - confirm which form is intended.
base_wd = "/content/drive/MyDrive" # @param ["/content/drive/MyDrive/nsm"] {"allow-input":true}
wd = base_wd + "/nsm"
print(f"\033[92mWorking directory set to: \n{wd}\033[0m")

In [None]:
#@title Set up environment and install NSM
import os
import sys

# Install pinned PyTorch / torchvision / torchaudio builds from the CUDA 12.4
# (cu124) wheel index
print("\033[92mSetting up environment...\033[0m")
print("\n\033[33m-----This will take a few minutes----\033[0m")
!pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# Install other dependencies
# (open3d appears in both the first and the third command below)
!pip install pyvista mskt open3d scikit-learn matplotlib pandas numpy scipy
!pip install ipywidgets
!pip install nibabel scikit-image opencv-python open3d

# Clone NSM repository into base_wd unless the working directory already exists
if not os.path.exists(wd):
    print("Cloning NSM repository...")
    os.makedirs(base_wd, exist_ok=True)
    %cd $base_wd
    !git clone https://github.com/aubricot/nsm.git
else:
    print("NSM directory already exists")

# Navigate to nsm directory and install
%cd $wd

# Install requirements (requirements.txt is read from the repository root)
print("\n-----Installing requirements-----")
!python -m pip install -r requirements.txt

# Install NSM package from the current directory
print("\n-----Installing NSM-----")
!pip install .

# Add to Python path so the repository checkout is importable, and make sure
# the notebook's working directory is the repository root for later cells
sys.path.insert(0, wd)
%cd $wd
print(f"\n\033[92mCurrent working directory set to: {os.getcwd()}\033[0m")

In [None]:
#@title Import libraries and define functions

# For rendering meshes
import pyvista as pv
pv.start_xvfb() # Enable PyVista for Colab
import plotly.graph_objects as go
import pymskt.mesh.meshes as meshes

# For working with ML
import torch
from NSM.helper_funcs import load_config, load_model_and_latents
from NSM.optimization import get_top_k_pcs
from NSM.helper_funcs import NumpyTransform, convert_ply_to_vtk
from NSM.optimization import (sample_near_surface,
    downsample_partial_pointcloud,
    optimize_latent_partial)
from NSM.datasets import SDFSamples
from NSM.mesh import create_mesh

# For working with data
import numpy as np
import random
import json

# Plot pyvista mesh interactively using plotly
def pv_to_plotly(mesh, color="deepskyblue", opacity=1.0):
    """Build a plotly ``Mesh3d`` trace from a PyVista mesh.

    The mesh surface is extracted and triangulated first, so every face is a
    triangle whose three vertex indices feed plotly's ``i``/``j``/``k``.

    Args:
        mesh: PyVista dataset providing ``extract_surface()``.
        color: Surface colour passed to plotly.
        opacity: Surface opacity passed to plotly.

    Returns:
        A ``plotly.graph_objects.Mesh3d`` trace.
    """
    surface = mesh.extract_surface().triangulate()

    # Rows of four values per face; column 0 is skipped, columns 1-3 hold the
    # triangle's vertex indices.
    tri_table = surface.faces.reshape(-1, 4)
    vertices = surface.points

    shading = dict(ambient=0.12, diffuse=0.88, specular=0.05,
                   roughness=0.9, fresnel=0.0)
    light_at = dict(x=0, y=0, z=2)

    return go.Mesh3d(
        x=vertices[:, 0],
        y=vertices[:, 1],
        z=vertices[:, 2],
        i=tri_table[:, 1],
        j=tri_table[:, 2],
        k=tri_table[:, 3],
        color=color,
        opacity=opacity,
        flatshading=False,
        lighting=shading,
        lightposition=light_at,
    )

# Plot pyvista pointcloud interactively using plotly
def pv_points_to_plotly(mesh, color='red', size=4):
    """Build a plotly ``Scatter3d`` marker trace from a dataset's points.

    Args:
        mesh: Object exposing a ``points`` array indexable as ``[:, 0..2]``.
        color: Marker colour passed to plotly.
        size: Marker size passed to plotly.

    Returns:
        A ``plotly.graph_objects.Scatter3d`` trace drawn in marker mode.
    """
    coords = mesh.points
    marker_style = dict(size=size, color=color, opacity=1.0)
    return go.Scatter3d(
        x=coords[:, 0],
        y=coords[:, 1],
        z=coords[:, 2],
        mode='markers',
        marker=marker_style,
    )

# Monkey patch for data types ----
from NSM.helper_funcs import safe_load_mesh_scalars, fixed_point_coords
# Replace pymskt's Mesh scalar loader and point_coords property with the NSM versions
meshes.Mesh.load_mesh_scalars = safe_load_mesh_scalars
meshes.Mesh.point_coords = property(fixed_point_coords)

import pymskt.mesh.meshTools as meshTools
# Capture the original function only while it is still unpatched. The wrapper
# below resolves `_original_signed_distance_to_mesh` at call time, so if this
# cell were re-run and the name rebound to the already-installed wrapper, the
# wrapper would call itself and recurse without end.
if not getattr(meshTools.pcu.signed_distance_to_mesh, "_nsm_dtype_patch", False):
    _original_signed_distance_to_mesh = meshTools.pcu.signed_distance_to_mesh

def _signed_distance_to_mesh_patch(pts, points, faces):
    """Call the original ``signed_distance_to_mesh`` with normalised dtypes.

    Args:
        pts: Query points; converted to a float64 array.
        points: Mesh vertices; converted to a float64 array.
        faces: Mesh faces; converted to an int32 array.

    Returns:
        Whatever the original ``signed_distance_to_mesh`` returns.
    """
    pts = np.asarray(pts, dtype=np.float64)     # force double precision
    points = np.asarray(points, dtype=np.float64)
    faces = np.asarray(faces, dtype=np.int32)   # ensure integer type for faces
    return _original_signed_distance_to_mesh(pts, points, faces)

# Tag the wrapper so a later run of this cell can recognise it as already installed
_signed_distance_to_mesh_patch._nsm_dtype_patch = True
meshTools.pcu.signed_distance_to_mesh = _signed_distance_to_mesh_patch
# End monkey patch ----

In [None]:
#@title Download models and meshes to appropriate folders

# Relative paths in this cell resolve against the notebook's current working
# directory; the cell returns to `wd` after the fossil download.

# Update these paths to point to your model and data
# The archive fetched by gdown is then unpacked into MODEL_DIR and deleted.
# NOTE(review): the unzip step expects the download to be named
# "{MODEL_DIR}.zip" - confirm if MODEL_DIR is changed from its default.
MODEL_DIR = "run_v44" # @param ["run_v44"] {"allow-input":true}
!gdown 1hRLyVdtqD2tF6wbE5m1Da0hLtHXiQ_oj
!unzip -o {MODEL_DIR}.zip -d {MODEL_DIR} && rm -f {MODEL_DIR}.zip

# Checkpoint to use (file name is the checkpoint label plus ".pth")
CKPT = "3000" # @param ["3000"] {"allow-input":true}
CKPT_fn = CKPT + '.pth'

# Fossil directory: created if missing, then the download is saved inside it
fossil_dir = "fossils" # @param ["fossils"] {"allow-input":true}
os.makedirs(fossil_dir, exist_ok=True)
%cd $fossil_dir
!gdown 15c9e_LNPlWfIHXa3EcBR0fWHvdOSjiSl

# Modern vertebrae directory: any existing copy is removed first, then the
# downloaded archive is unpacked into it and the archive deleted.
# NOTE(review): the unzip step expects the download to be named
# "{vertebrae_dir}.zip" - confirm if vertebrae_dir is changed from its default.
vertebrae_dir = "vertebrae_meshes" # @param ["vertebrae_meshes"] {"allow-input":true}
%cd $wd
!rm -rf $vertebrae_dir # Delete demo vertebrae_meshes dir from nsm github
!gdown 1EaQJEfryoziFjdfYmI2-UPoF0wvhdnhS
!unzip -o {vertebrae_dir}.zip -d {vertebrae_dir} && rm -f {vertebrae_dir}.zip

# Output directory (created if missing)
OUTPUT_DIR = "shape_completion/predictions" # @param ["outputs"] {"allow-input":true}
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n\033[92mSet up working directory and downloaded model and mesh files.")
print(f"Model directory: {MODEL_DIR}")
print(f"Checkpoint: {CKPT}")
print(f"Output directory: {OUTPUT_DIR}\033[0m")

## 2. Shape Completion

Complete the shape from a partial mesh (modern or fossil).


In [None]:
#@title Load model and latent codes

# Change to model directory so the relative config / checkpoint paths below resolve
%cd $MODEL_DIR

# Load config
# The device comes from the config when it defines one; otherwise CUDA is used
# if available, else CPU.
config = load_config(config_path='model_params_config.json')
device = config.get("device", "cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Paths to model and latent codes (relative to MODEL_DIR)
LC_PATH = f'latent_codes/{CKPT}.pth'
MODEL_PATH = f'model/{CKPT}.pth'

# Load model and latents
print("Loading model and latents...")
model, latent_ckpt, latent_codes = load_model_and_latents(MODEL_PATH, LC_PATH, config, device)

# Compute statistics used later by the latent optimization
# mean_latent: mean over dim 0, kept as a leading dimension of size 1
# NOTE(review): .std() without a dim reduces over every element, so the
# trailing .mean() acts on a single value - confirm whether a per-dimension
# std (dim=0) was intended.
mean_latent = latent_codes.mean(dim=0, keepdim=True)
latent_std = latent_codes.std().mean()
_, top_k_reg = get_top_k_pcs(latent_codes, threshold=0.99)

# Return to original directory
%cd $wd

print(f"\nLatent size: {config['latent_size']}")
print(f"Number of training samples: {len(latent_codes)}")

In [None]:
#@title Load mesh into latent space

# Pick a mesh
mesh_dir = fossil_dir # @param ["fossil_dir","vertebrae_dir"] {"type":"raw","allow-input":true}
# Choose only among mesh files: a raw directory listing can also contain
# sub-directories or other non-mesh files, which would break the steps below.
# The rest of the notebook handles .ply (converted) and .vtk inputs.
mesh_files = sorted(f for f in os.listdir(mesh_dir)
                    if f.lower().endswith(('.ply', '.vtk')))
if not mesh_files:
    raise FileNotFoundError(f"No .ply or .vtk meshes found in directory: {mesh_dir}")
mesh_path = random.choice(mesh_files)
print(f"Mesh being loaded from directory: {mesh_dir}\n{mesh_path}\n")

# Setup output directory (one sub-folder per mesh, named after the file stem)
mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]
outfpath = os.path.join(OUTPUT_DIR, mesh_name)
os.makedirs(outfpath, exist_ok=True)
print(f"Saving results to output directory: {outfpath}")

# Convert PLY to VTK if needed
mesh_path = os.path.join(mesh_dir, mesh_path)
vert_fname = mesh_path
# Test the file extension rather than searching the whole path, so a directory
# name containing ".ply" cannot trigger a conversion.
if mesh_path.lower().endswith('.ply'):
    print("Converting PLY to VTK...")
    mesh, vert_fname = convert_ply_to_vtk(mesh_path, save=True)

# Setup dataset (sampling settings are taken from the trained model's config)
print("\n-----Setting up dataset-----")
sdf_dataset = SDFSamples(
    list_mesh_paths=[vert_fname],
    multiprocessing=False,
    subsample=config["samples_per_object_per_batch"],
    print_filename=True,
    n_pts=config["n_pts_per_object"],
    p_near_surface=config['percent_near_surface'],
    p_further_from_surface=config['percent_further_from_surface'],
    sigma_near=config['sigma_near'],
    sigma_far=config['sigma_far'],
    rand_function=config['random_function'],
    center_pts=config['center_pts'],
    norm_pts=config['normalize_pts'],
    scale_method=config['scale_method'],
    reference_mesh=None,
    verbose=config['verbose'],
    save_cache=config['cache'],
    equal_pos_neg=config['equal_pos_neg'],
    fix_mesh=config['fix_mesh'])

# Get SDF data
sdf_sample = sdf_dataset[0]
sample_dict, _ = sdf_sample
points = sample_dict['xyz'].to(device)
sdf_vals = sample_dict['gt_sdf']

# Downsample partial pointcloud and save it so it can be plotted later
# NOTE(review): the meaning of the second argument (180) is defined by
# NSM.optimization.downsample_partial_pointcloud - confirm there.
print("\n-----Preparing partial pointcloud-----")
partial_pts = downsample_partial_pointcloud(vert_fname, 180)
partial_pts = torch.tensor(partial_pts, dtype=torch.float32)
partial_cloud = pv.PolyData(partial_pts.cpu().numpy())
partial_cloud.save(os.path.join(outfpath, f"{mesh_name}_partial_input.vtk"))

# Sample points with SDF values
partial_pts, sdfs = sample_near_surface(
    partial_pts, eps=0.005, fraction_nonzero=0.4,
    fraction_far=0.05, far_eps=0.05)

# Optimize latents
# Phase 1 starts from the training latent codes and passes the mean latent.
print("\n-----Optimizing latents (Phase 1: Coarse)-----")
latent_partial, _ = optimize_latent_partial(
    model, partial_pts, sdfs, config['latent_size'],
    mean_latent=mean_latent, latent_init=latent_codes, top_k=top_k_reg,
    iters=5000, lr=1e-4, lambda_reg=1e-3, clamp_val=2.0,
    latent_std=latent_std, scheduler_step=800, scheduler_gamma=0.8,
    batch_inference_size=32768, multi_stage=False, device=device)

# Phase 2 starts from the Phase 1 result, with a smaller learning rate and
# regularisation weight and no clamping.
print("\n-----Optimizing latents (Phase 2: Refinement)-----")
latent_partial, _ = optimize_latent_partial(
    model, partial_pts, sdfs, config['latent_size'],
    latent_init=latent_partial, top_k=top_k_reg,
    iters=8000, lr=1.3e-5, lambda_reg=7e-5, clamp_val=None,
    latent_std=latent_std, scheduler_step=800, scheduler_gamma=0.7,
    batch_inference_size=32768, multi_stage=True, device=device)

print("\n\033[92mLatent optimization complete\033[0m")

In [None]:
#@title Complete and reconstruct mesh

# Reconstruction parameters
# The grid origin is (-recon_grid_origin) on each axis, and the spacing is
# chosen so n_pts_per_axis samples cover a length of 2 * recon_grid_origin.
recon_grid_origin = 1.0
n_pts_per_axis = 256 # @param ["256","128","384"] {"type":"raw"}
voxel_origin = (-recon_grid_origin,) * 3
voxel_size = (recon_grid_origin * 2) / (n_pts_per_axis - 1)

# Identity placement: zero offset, unit scale, identity 4x4 transform
offset = np.zeros(3)
scale = 1.0
icp_transform = NumpyTransform(np.eye(4))
objects = 1

# Reconstruct mesh (gradients are not needed for decoding)
print("\n\033[93m-----Reconstructing mesh-----\033[0m")
with torch.no_grad():
    mesh_out = create_mesh(
        decoder=model,
        latent_vector=latent_partial,
        n_pts_per_axis=n_pts_per_axis,
        voxel_origin=voxel_origin,
        voxel_size=voxel_size,
        path_original_mesh=None,
        offset=offset,
        scale=scale,
        icp_transform=icp_transform,
        objects=objects,
        verbose=True,
        device=device,
        smooth=1.0,
        scale_to_original_mesh=False,
    )

# Ensure it's PyVista PolyData: unwrap a list result, then convert any other
# dataset type through extract_geometry()
mesh_out = mesh_out[0] if isinstance(mesh_out, list) else mesh_out
mesh_pv = mesh_out if isinstance(mesh_out, pv.PolyData) else mesh_out.extract_geometry()

# Clean and triangulate
mesh_pv = mesh_pv.clean().triangulate()

# Save mesh with a single uniform per-point RGB colour as its only point data
output_path = os.path.join(outfpath, f"{mesh_name}_shape_completion.vtk")
color = np.array([112, 215, 222], dtype=np.uint8)  # RGB color
rgb = np.tile(color, (mesh_pv.n_points, 1))
mesh_pv.point_data.clear()
mesh_pv.point_data['Colors'] = rgb
mesh_pv.save(output_path)

print(f"\n\033[92mCompleted mesh saved to: {output_path}\033[0m")
print(f"Number of points: {mesh_pv.n_points}")
print(f"Number of faces: {mesh_pv.n_faces_strict}")

## 3. Inspect Results

In [None]:
#@title Plot the original mesh

# Read the VTK file named after the selected mesh from mesh_dir and compute normals
original_mesh = pv.read(os.path.join(mesh_dir, f"{mesh_name}.vtk"))
original_mesh.compute_normals(inplace=True)

# Layout pieces, named so the update_layout call below stays readable
title_cfg = dict(text=f"Original Mesh (before completion)<br>{mesh_name}",
                 x=0.5, y=0.95, xanchor="center", yanchor="top")
legend_cfg = dict(x=1.02, y=1, bgcolor="rgba(255,255,255,0.7)",
                  bordercolor="black", borderwidth=1)
margin_cfg = dict(l=10, r=10, b=10, t=80)

# Build the figure with a single surface trace
fig = go.Figure()
trace = pv_to_plotly(original_mesh, 'goldenrod', 1)
trace.name = "Original mesh"
fig.add_trace(trace)

# Show a legend entry for every trace held by the figure
for trace in fig.data:
    trace.showlegend = True

fig.update_layout(
    title=title_cfg,
    showlegend=True,
    scene_aspectmode='data',
    legend=legend_cfg,
    margin=margin_cfg,
)
fig.show()

In [None]:
#@title Plot the completed mesh vs sampled point cloud

# Read back the saved partial point cloud and the completed mesh
partial_mesh = pv.read(os.path.join(outfpath, f"{mesh_name}_partial_input.vtk"))
completed_mesh = pv.read(output_path)

# Layout pieces, named so the update_layout call below stays readable
title_cfg = dict(text=f"Completed mesh<br>{mesh_name}",
                 x=0.5, y=0.95, xanchor="center", yanchor="top")
legend_cfg = dict(x=1.02, y=1, bgcolor="rgba(255,255,255,0.7)",
                  bordercolor="black", borderwidth=1)
margin_cfg = dict(l=10, r=10, b=10, t=80)

# Build the figure: completed surface first, then the partial input points
fig = go.Figure()

trace = pv_to_plotly(completed_mesh, 'deepskyblue', 1)  # Completed surface
trace.name = "Completed mesh"
fig.add_trace(trace)

trace = pv_points_to_plotly(partial_mesh, 'darkseagreen', size=5)  # Partial point cloud
trace.name = "Partial point cloud"
fig.add_trace(trace)

# Show a legend entry for every trace held by the figure
for trace in fig.data:
    trace.showlegend = True

fig.update_layout(
    title=title_cfg,
    showlegend=True,
    scene_aspectmode='data',
    legend=legend_cfg,
    margin=margin_cfg,
)
fig.show()