<a href="https://colab.research.google.com/github/aubricot/nsm/blob/main/demos/classification_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Squamate Vertebra Classification Demo**   
*Last edited 19 Jan 2026*

This notebook demonstrates classification of vertebra species and position (modern and fossil) using a trained Neural Shape Model (NSM; Gatti et al. 2025, Park et al. 2019). It can be run fully in demo mode without connecting to your Google Drive. Adjust parameters using form fields and make sure your runtime environment is set to run on GPU. Full repository code is available at [aubricot/nsm on GitHub](https://github.com/aubricot/nsm).

Modern vertebra meshes are derived from micro-CT data produced by the oVert Initiative (Blackburn et al. 2024). Fossil vertebrae were downloaded from MorphoSource ([UF546657](https://doi.org/10.17602/M2/M600663); [UF271967](https://n2t.net/ark:/87602/m4/M69199)). All vertebrae were aligned and scaled using ATLAS before training (Porto et al. 2026).


**References**
* Blackburn et al. 2024, BioScience. https://doi.org/10.1093/biosci/biad120
* Gatti et al. 2025, IEEE TMI. https://doi.org/10.1109/tmi.2024.3485613
* Park et al. 2019, CVPR. https://doi.org/10.48550/arXiv.1901.05103
* Porto et al. 2026, in prep. https://github.com/agporto/ATLAS

## 1. Installs & Imports

In [None]:
#@title Check GPU and CUDA info - make sure Colab Runtime set to GPU
from psutil import virtual_memory

# Check GPU and CUDA
!nvcc --version
gpu = !nvidia-smi
gpu = '\n'.join(gpu)
print('\033[91mNot connected to a GPU\033[0m' if 'failed' in gpu else gpu)

# Check RAM
ram = virtual_memory().total / 1e9
print(f'\033[92mYour runtime has {ram:.1f} GB of RAM\033[0m\n')

In [None]:
#@title Choose where to save results
import os
import sys

# Use dropdown menu on right
save = "in Colab runtime (files deleted after each session)" #@param ["in my Google Drive", "in Colab runtime (files deleted after each session)"]

# Mount Google Drive only when results should be kept after the session ends
if 'Google Drive' in save:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=False)

# Type in the path to the PARENT folder of the project in form field on right.
# The "nsm" repository folder is appended below, so do not include it here.
base_wd = "/content/drive/MyDrive" # @param ["/content/drive/MyDrive"] {"allow-input":true}
# Strip trailing slashes so the joined path never contains "//nsm"
wd = os.path.join(base_wd.rstrip("/") or "/", "nsm")
print(f"\033[92mWorking directory set to: \n{wd}\033[0m")

In [None]:
#@title Set up environment and install NSM
import os
import sys

# Install PyTorch with CUDA support (Colab typically has CUDA 11.8 or 12.x)
print("\033[92mSetting up environment...\033[0m")
print("\n\033[33m-----This will take a few minutes----\033[0m")
!pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124

# Install other dependencies
!pip install pyvista mskt open3d scikit-learn matplotlib pandas numpy scipy
!pip install ipywidgets
!pip install nibabel scikit-image opencv-python open3d

# Clone NSM repository
if not os.path.exists(wd):
    print("Cloning NSM repository...")
    os.makedirs(base_wd, exist_ok=True)
    %cd $base_wd
    !git clone https://github.com/aubricot/nsm.git
else:
    print("NSM directory already exists")

# Navigate to nsm directory and install
%cd $wd

# Install requirements
print("\n-----Installing requirements-----")
!python -m pip install -r requirements.txt

# Install NSM package
print("\n-----Installing NSM-----")
!pip install .

# Add to Python path
sys.path.insert(0, wd)
%cd $wd
print(f"\n\033[92mCurrent working directory set to: {os.getcwd()}\033[0m")

In [None]:
#@title Import libraries and define functions

# For rendering meshes
import pyvista as pv
pv.start_xvfb() # Enable PyVista for Colab
import plotly.graph_objects as go
import pymskt.mesh.meshes as meshes
import vtk

# For working with ML
import torch
import torch.nn.functional as F
from NSM.helper_funcs import load_config, load_model_and_latents
from NSM.optimization import get_top_k_pcs
from NSM.helper_funcs import NumpyTransform, convert_ply_to_vtk
from NSM.optimization import (sample_near_surface,
    downsample_partial_pointcloud,
    optimize_latent_partial)
from NSM.datasets import SDFSamples
from NSM.mesh import create_mesh

# For working with data
import numpy as np
import pandas as pd
import random
import json
import re
from pathlib import Path
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Identify novel meshes from latent space
from NSM.models import TriplanarDecoder
from NSM.mesh import get_sdfs
from NSM.helper_funcs import NumpyTransform, load_config, load_model_and_latents, convert_ply_to_vtk, get_sdfs, fixed_point_coords, safe_load_mesh_scalars, extract_species_prefix, parse_labels_from_filepaths
from NSM.optimization import pca_initialize_latent, get_top_k_pcs, find_similar, find_similar_cos, optimize_latent

# Plot pyvista mesh interactively using plotly
def pv_to_plotly(mesh, color="deepskyblue", opacity=1.0):
    """Convert a PyVista mesh into a plotly ``Mesh3d`` trace.

    Args:
        mesh: Object exposing ``extract_surface()``; the extracted surface is
            triangulated before conversion.
        color: Colour passed through to the plotly trace.
        opacity: Opacity passed through to the plotly trace.

    Returns:
        A ``go.Mesh3d`` trace with fixed lighting settings.
    """
    surface = mesh.extract_surface().triangulate()
    # Rows of four values per triangle; columns 1-3 are used as the vertex
    # indices (column 0 is presumably the per-face vertex count - TODO confirm).
    triangles = surface.faces.reshape(-1, 4)
    vertices = surface.points
    lighting = dict(ambient=0.12, diffuse=0.88, specular=0.05,
                    roughness=0.9, fresnel=0.0)
    light_position = dict(x=0, y=0, z=2)
    return go.Mesh3d(
        x=vertices[:, 0],
        y=vertices[:, 1],
        z=vertices[:, 2],
        i=triangles[:, 1],
        j=triangles[:, 2],
        k=triangles[:, 3],
        color=color,
        opacity=opacity,
        flatshading=False,
        lighting=lighting,
        lightposition=light_position,
    )

def plot_predictions(dim_reduced_coords, similar_coords, novel_coord, filepaths, out_fn, title=None):
    """Plot a 2-D latent-space embedding, highlight the matches, and save it.

    Relies on two notebook-level globals that must be defined before calling:
    ``similar_ids`` (indices into ``filepaths`` for the similar meshes, in the
    same order as ``similar_coords``) and ``outfpath`` (output directory).

    Args:
        dim_reduced_coords: 2-D coordinates of the training meshes; columns 0
            and 1 are plotted.
        similar_coords: 2-D coordinates of the most similar meshes, best first.
        novel_coord: (x, y) coordinate of the novel mesh.
        filepaths: File names indexed by ``similar_ids``; the text before the
            first "." is used as the point label.
        out_fn: File name of the saved figure inside ``outfpath``.
        title: Optional plot title. When omitted, the embedding method is
            inferred from ``out_fn`` ("tsne" in the name means t-SNE,
            otherwise PCA), so t-SNE plots are no longer labelled "(PCA)".
    """
    if title is None:
        method = "t-SNE" if "tsne" in out_fn.lower() else "PCA"
        title = f"Latent Space Visualization ({method})"

    plt.figure(figsize=(8, 6))
    plt.scatter(dim_reduced_coords[:, 0], dim_reduced_coords[:, 1], color='gray', alpha=0.3, label='Training Meshes')
    # Plot most similar (1st one) in pink; skipped when no matches were given
    if len(similar_coords) > 0:
        plt.scatter(similar_coords[0, 0], similar_coords[0, 1], color='hotpink', s=80, label='Most Similar')
    # Plot the remaining similar meshes in blue
    if len(similar_coords) > 1:
        plt.scatter(similar_coords[1:, 0], similar_coords[1:, 1], color='blue', s=60, label='Other Top-5 Similar')
    # Plot novel mesh in red
    plt.scatter(*novel_coord, color='red', s=80, label='Novel Mesh')
    # Annotate each of the similar meshes with its file name (extension dropped)
    for idx, (x, y) in zip(similar_ids, similar_coords):
        plt.text(x, y, filepaths[idx].split('.')[0], fontsize=6, color='black')
    plt.title(title)
    plt.xlabel("Component 1")
    plt.ylabel("Component 2")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"{outfpath}/{out_fn}", dpi=300)
    plt.show()
    plt.close()

# Monkey patch for data types ----
from NSM.helper_funcs import safe_load_mesh_scalars, fixed_point_coords
meshes.Mesh.load_mesh_scalars = safe_load_mesh_scalars
meshes.Mesh.point_coords = property(fixed_point_coords)

import pymskt.mesh.meshTools as meshTools

# Capture the library function only while it is still unpatched. Re-running
# this cell would otherwise store the wrapper as the "original" and stack one
# wrapper on top of another. The installed wrapper carries a marker attribute.
if not getattr(meshTools.pcu.signed_distance_to_mesh, "_nsm_dtype_patch", False):
    _original_signed_distance_to_mesh = meshTools.pcu.signed_distance_to_mesh

def _signed_distance_to_mesh_patch(pts, points, faces):
    """Call the original ``signed_distance_to_mesh`` with normalised dtypes.

    ``pts`` and ``points`` are converted to float64 and ``faces`` to int32
    before delegating (presumably the dtypes the underlying routine expects -
    TODO confirm). Returns whatever the original function returns.
    """
    pts = np.asarray(pts, dtype=np.float64)     # force double precision
    points = np.asarray(points, dtype=np.float64)
    faces = np.asarray(faces, dtype=np.int32)   # ensure integer type for faces
    return _original_signed_distance_to_mesh(pts, points, faces)

_signed_distance_to_mesh_patch._nsm_dtype_patch = True
meshTools.pcu.signed_distance_to_mesh = _signed_distance_to_mesh_patch
# End monkey patch ----

In [None]:
#@title Download models and meshes to appropriate folders
# Downloads four zip archives with gdown (by file id), extracts each into a
# folder named after the corresponding form field, and deletes the archive.
# All folder names are relative, so this cell depends on the current working
# directory (it switches to wd part-way through, see below).
# NOTE(review): each unzip assumes the downloaded archive is named
# "<folder>.zip" - confirm against the shared files.

# Update these paths to point to your model and data
MODEL_DIR = "run_v44" # @param ["run_v44"] {"allow-input":true}
!gdown 1hRLyVdtqD2tF6wbE5m1Da0hLtHXiQ_oj
!unzip -o {MODEL_DIR}.zip -d {MODEL_DIR} && rm -f {MODEL_DIR}.zip

# Checkpoint to use
CKPT = "3000" # @param ["3000"] {"allow-input":true}
# Checkpoint file name with extension
CKPT_fn = CKPT + '.pth'

# Fossil directory
fossil_dir = "fossils" # @param ["fossils"] {"allow-input":true}
#os.makedirs(fossil_dir, exist_ok=True)
#%cd $fossil_dir
!gdown 1UgKYDj4d5d0D4M8MfHFAh-IkW-dkmujf
!unzip -o {fossil_dir}.zip -d {fossil_dir} && rm -f {fossil_dir}.zip

# Modern vertebrae directory
vertebrae_dir = "vertebrae_meshes" # @param ["vertebrae_meshes"] {"allow-input":true}
# From here on the downloads land in wd (the cloned repository folder)
%cd $wd
!rm -rf $vertebrae_dir # Delete demo vertebrae_meshes dir from nsm github
!gdown 1EaQJEfryoziFjdfYmI2-UPoF0wvhdnhS
!unzip -o {vertebrae_dir}.zip -d {vertebrae_dir} && rm -f {vertebrae_dir}.zip

# Output directory
OUTPUT_DIR = "classification" # @param ["outputs"] {"allow-input":true}
os.makedirs(OUTPUT_DIR, exist_ok=True)
# Download inside the output folder, then return to wd before extracting
%cd $OUTPUT_DIR
!gdown 19V3DlpthWjI_5ttmmxepeY20iI87LB0N
# OUTPUT_DIR is redefined to the "predictions" sub-folder; later cells write
# their results below this path
OUTPUT_DIR = OUTPUT_DIR + "/predictions"
%cd $wd
!unzip -o {OUTPUT_DIR}.zip -d {OUTPUT_DIR} && rm -f {OUTPUT_DIR}.zip

# Summary (green ANSI colour is switched on here and reset on the last line)
print(f"\n\033[92mSet up working directory and downloaded model and mesh files.")
print(f"Model directory: {MODEL_DIR}")
print(f"Checkpoint: {CKPT}")
print(f"Output directory: {OUTPUT_DIR}\033[0m")

## 2. Classification

Classify the species and spinal position of a novel squamate vertebra mesh (modern or fossil).


In [None]:
#@title Load model and latent codes
# Loads the training config, the trained model and the training latent codes
# from MODEL_DIR, then derives the statistics used when a novel mesh is
# encoded. The paths below are relative, so the cell changes into MODEL_DIR
# first and returns to wd at the end.

# Change to model directory
%cd $MODEL_DIR

# Load config
config = load_config(config_path='model_params_config.json')
# Use the device named in the config; only when the key is missing fall back
# to the first GPU if available, otherwise the CPU
device = config.get("device", "cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Training dataset mesh names
train_paths = config['list_mesh_paths']
# File names only; indexed by the similar-mesh indices in later cells
all_vtk_files = [os.path.basename(f) for f in train_paths]

# Paths to model and latent codes
LC_PATH = f'latent_codes/{CKPT}.pth'
MODEL_PATH = f'model/{CKPT}.pth'

# Load model and latents
print("Loading model and latents...")
model, latent_ckpt, latent_codes = load_model_and_latents(MODEL_PATH, LC_PATH, config, device)

# Compute statistics
# Mean over dim 0 with the dimension kept
mean_latent = latent_codes.mean(dim=0, keepdim=True)
# NOTE(review): .std() without a dim already reduces to a single scalar, so
# the trailing .mean() has no effect - confirm whether a per-dimension std
# (dim=0) was intended. latent_std is not used in the cells visible here.
latent_std = latent_codes.std().mean()
# Presumably the number of principal components reaching the 0.99 threshold -
# TODO confirm against NSM.optimization.get_top_k_pcs. Output file names in
# later cells say "95pct" although the threshold here is 0.99.
_, top_k_reg = get_top_k_pcs(latent_codes, threshold=0.99)

# Return to original directory
%cd $wd

print(f"\nLatent size: {config['latent_size']}")
print(f"Number of training samples: {len(latent_codes)}")

In [None]:
#@title Load mesh into latent space

# Pick a mesh
mesh_dir = fossil_dir # @param ["fossil_dir","vertebrae_dir"] {"type":"raw","allow-input":true}
# Choose only among real mesh files. The folder may also contain sub-folders,
# hidden files or other by-products, and the later plotting cell reads
# "<mesh_name>.vtk", so only .ply (converted below) and .vtk are usable.
mesh_candidates = sorted(
    fname for fname in os.listdir(mesh_dir)
    if os.path.isfile(os.path.join(mesh_dir, fname))
    and fname.lower().endswith(('.ply', '.vtk'))
)
if not mesh_candidates:
    raise FileNotFoundError(f"No .ply or .vtk meshes found in directory: {mesh_dir}")
mesh_path = random.choice(mesh_candidates)
print(f"Mesh being loaded from directory: {mesh_dir}\n{mesh_path}\n")

# Setup output directory (one sub-folder per mesh, named after the file stem)
mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]
outfpath = os.path.join(OUTPUT_DIR, mesh_name)
os.makedirs(outfpath, exist_ok=True)
print(f"Saving results to output directory: {outfpath}")

# Set up output path for novel mesh
output_path = os.path.join(outfpath, f"{mesh_name}_decoded_novel_pca_regularized_95pct_cos.vtk")

# Convert PLY to VTK if needed
mesh_path = os.path.join(mesh_dir, mesh_path)
vert_fname = mesh_path
# Test the file extension only, not any ".ply" substring elsewhere in the path
if mesh_path.lower().endswith('.ply'):
    print("Converting PLY to VTK...")
    mesh, vert_fname = convert_ply_to_vtk(mesh_path, save=True)

# Setup dataset: a one-mesh dataset sampled with the same settings as training
summary_log = []
print("\n-----Setting up dataset-----")
sdf_dataset = SDFSamples(
    list_mesh_paths=[vert_fname],
    multiprocessing=False,
    subsample=config["samples_per_object_per_batch"],
    print_filename=True,
    n_pts=config["n_pts_per_object"],
    p_near_surface=config['percent_near_surface'],
    p_further_from_surface=config['percent_further_from_surface'],
    sigma_near=config['sigma_near'],
    sigma_far=config['sigma_far'],
    rand_function=config['random_function'],
    center_pts=config['center_pts'],
    norm_pts=config['normalize_pts'],
    scale_method=config['scale_method'],
    reference_mesh=None,
    verbose=config['verbose'],
    save_cache=config['cache'],
    equal_pos_neg=config['equal_pos_neg'],
    fix_mesh=config['fix_mesh'])

# Get SDF data
sdf_sample = sdf_dataset[0]
sample_dict, _ = sdf_sample
points = sample_dict['xyz'].to(device)
# NOTE(review): unlike points, sdf_vals is not moved to device here -
# presumably optimize_latent handles that; confirm.
sdf_vals = sample_dict['gt_sdf']

# Optimize latents (DeepSDF has no encoder, so must use optimization to encode novel data)
print("\n-----Optimizing latents-----")
latent_novel = optimize_latent(model, points, sdf_vals, config['latent_size'], top_k_reg, mean_latent, latent_codes)
print("Translated novel mesh into latent space!")

# Classify vertebra

# Find most similar latents (Compare to existing latents)
print("\n-----Finding most similar meshes-----")
similar_ids, distances = find_similar_cos(latent_novel, latent_codes, top_k=5, n_std=2, device=device)

# Write most similar meshes to txt file. Layout: one title line, one
# comma-separated header line, then one "name, index, distance" row per match.
# The results cell reads this file back with header=1 and looks up the
# 'Name: ' column, so the header text must not change.
sim_mesh_fpath = outfpath + '/' + 'similar_meshes_pca_regularized_95pct_cos.txt'
with open(sim_mesh_fpath, "w") as f:
    print(f"Most similar mesh indices to file: {os.path.basename(vert_fname)}\n")
    f.write(f"Most similar mesh indices to file: {os.path.basename(vert_fname)}:\n")
    header = "Name: , Index: , Distance:  "
    f.write(header + "\n")
    for i, d in zip(similar_ids, distances):
        # i indexes the training file list, d is the matching distance
        line = f"{all_vtk_files[i]}, {i}, {d:.4f}"
        print(line)
        f.write(line + "\n")
print(f"\n\033[92mMost similar meshes written to file: {sim_mesh_fpath}\033[0m")

In [None]:
# Inspect novel latent using clustering analysis

# Data loading (shared by both plots). Detach before converting so tensors
# that still track gradients can be turned into NumPy arrays in both sections.
latents = latent_codes.detach().cpu().numpy()
latent_novel_np = latent_novel.detach().cpu().numpy()

# PCA Plot
# Fit on the training latents only, then project the novel latent
pca = PCA(n_components=2)
coords_2d = pca.fit_transform(latents)
novel_coord = pca.transform(latent_novel_np)[0]
similar_coords = coords_2d[similar_ids]
plot_predictions(coords_2d, similar_coords, novel_coord, all_vtk_files, out_fn="latent_space_pca_pca_regularized_95pct_cos.png")
print('\n\n\n')

# t-SNE Plot
# The novel latent is embedded together with the training latents, as the
# last row, and split off again after fitting
latents_with_novel = np.vstack([latents, latent_novel_np])
# scikit-learn requires perplexity < n_samples; cap it for very small datasets
tsne_perplexity = min(30, len(latents_with_novel) - 1)
tsne = TSNE(n_components=2, perplexity=tsne_perplexity, learning_rate=200, random_state=42)
coords_with_novel = tsne.fit_transform(latents_with_novel)
train_coords = coords_with_novel[:-1]
novel_coord = coords_with_novel[-1]
similar_coords = train_coords[similar_ids]
plot_predictions(train_coords, similar_coords, novel_coord, all_vtk_files, "latent_space_tsne_pca_regularized_95pct_cos.png")

## 3. Inspect Results

In [None]:
#@title View the top-5 most similar meshes

# The similar-meshes file starts with a title line; row 1 (zero-indexed)
# holds the column names, so it is used as the header
df = pd.read_csv(filepath_or_buffer=sim_mesh_fpath, header=1)
# Show the first rows as the cell output
df.head(n=5)

In [None]:
#@title Plot the original mesh

# Load the VTK file of the novel mesh from the directory it was picked from
orig_mesh_name = mesh_name
original_mesh = pv.read(os.path.join(mesh_dir, mesh_name + ".vtk"))
original_mesh.compute_normals(inplace=True)

# Build the mesh trace first, then attach it to a fresh figure
trace = pv_to_plotly(original_mesh, 'goldenrod', 1)
trace.name = "Original mesh"
fig = go.Figure()
fig.add_trace(trace)
for trace in fig.data:
    trace.showlegend = True

# Centred two-line title, legend box placed just right of the plot area
fig.update_layout(
    title={
        "text": f"Original Mesh (before completion)<br>{mesh_name}",
        "x": 0.5,
        "y": 0.95,
        "xanchor": "center",
        "yanchor": "top",
    },
    showlegend=True,
    scene_aspectmode='data',
    legend={
        "x": 1.02,
        "y": 1,
        "bgcolor": "rgba(255,255,255,0.7)",
        "bordercolor": "black",
        "borderwidth": 1,
    },
    margin={"l": 10, "r": 10, "b": 10, "t": 80},
)
fig.show()

In [None]:
#@title Randomly select and plot meshes from the top-5 most similar

# Keep only the similar meshes whose files exist in the modern vertebrae
# directory. A comprehension is used so the notebook-level `mesh_name` (the
# novel mesh, needed by the previous cells) is not overwritten.
mesh_list = [
    candidate for candidate in df['Name: ']
    if os.path.isfile(os.path.join(vertebrae_dir, candidate))
]

# Print the mesh list with the files that exist in the directory
print("Meshes found in directory:", mesh_list)
if not mesh_list:
    raise FileNotFoundError(
        f"None of the top-5 similar meshes were found in directory: {vertebrae_dir}")

# Read mesh (names in the similar-meshes file already include the extension)
similar_mesh_name = random.choice(mesh_list)
top_mesh = pv.read(os.path.join(vertebrae_dir, similar_mesh_name))
print("Inspecting randomly chosen similar mesh: ", top_mesh)
top_mesh.compute_normals(inplace=True)

# Plot figure
fig = go.Figure()
trace = pv_to_plotly(top_mesh, 'deepskyblue', 1)
trace.name = "Top-5 Similar Mesh"
fig.add_trace(trace)
for trace in fig.data:
    trace.showlegend = True
fig.update_layout(title=dict(text=f"Top-5 Similar Mesh<br>{similar_mesh_name}",
                             x=0.5, y=0.95, xanchor="center", yanchor="top"),
                  showlegend=True,
                  scene_aspectmode='data',
                  legend=dict(x=1.02, y=1, bgcolor="rgba(255,255,255,0.7)",
                              bordercolor="black", borderwidth=1),
                  margin=dict(l=10, r=10, b=10, t=80))
fig.show()