## Interpolate known landmarks from mean/median specimen to novel/generated meshes
---
*Last edited 4 Nov 2025 by K. Wolcott*   
Use this notebook to examine how the model interprets morphology and landmark positions by investigating how landmarks are transferred onto novel/generated meshes from randomly sampled latents.   

In [None]:
# Notebook setup: pick the GPU, move into the training run directory, and
# derive the checkpoint / latent-code paths that the later cells read.
import os

# Optional TO DO: Set which GPU to use (0 or 1)
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# TO DO: Choose training directory containing model ckpt and latent codes
TRAIN_DIR = "run_v51"

# Only descend into TRAIN_DIR when its name does not already occur in the
# current path, so re-running this cell does not try to enter it twice.
cwd = os.getcwd()
if TRAIN_DIR not in cwd:
    os.chdir(f"{cwd}/{TRAIN_DIR}")

# TO DO: Choose the ckpt value you want to analyze results for
CKPT = "1500"
# Both paths are relative to the working directory selected above.
LC_PATH = f"latent_codes/{CKPT}.pth"
MODEL_PATH = f"model/{CKPT}.pth"

# The message starts with an ANSI colour escape; no reset code is emitted here.
print("\033[92mWorking directory set to: ", os.getcwd())

In [None]:
# Define functions
from NSM.mesh.interpolate import interpolate_points, interpolate_mesh, interpolate_common
from NSM.datasets.sdf_dataset import get_pts_center_and_scale
from NSM.mesh import create_mesh
from NSM.models import TriplanarDecoder
from scipy.spatial import cKDTree
import torch
import torch.nn.functional as F

import pyvista as pv
import vtk
import open3d as o3d
from itkwidgets import view

import pymskt as mskt
import matplotlib.pyplot as plt

from pathlib import Path
import numpy as np
import json
import pandas as pd
import shutil

# For setting up icp transform in mesh generation
class NumpyTransform:
    """Thin wrapper that exposes a 4x4 matrix through a ``GetMatrix`` method.

    The stored matrix is indexed as ``matrix[row, col]``, so it is expected to
    support 2-D tuple indexing (e.g. a NumPy array).
    """

    def __init__(self, matrix):
        # Kept by reference; nothing is copied or validated here.
        self.matrix = matrix

    def GetMatrix(self):
        """Return a new ``vtk.vtkMatrix4x4`` filled element-wise from ``self.matrix``."""
        result = vtk.vtkMatrix4x4()
        for row, col in ((r, c) for r in range(4) for c in range(4)):
            result.SetElement(row, col, self.matrix[row, col])
        return result

# Load model config file
def load_config(config_path='model_params_config.json'):
    """Read the model configuration JSON file and return its parsed content.

    Args:
        config_path: Path to the JSON config file. Defaults to
            ``model_params_config.json`` in the current working directory.

    Returns:
        The object produced by ``json.load`` for the file.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist. The original
            OS-level error is attached as the explicit cause.
    """
    # Keep the try body limited to the operation that can raise, and read the
    # file as UTF-8 so parsing does not depend on the platform locale.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Error: model_params_config.json not found at {config_path}") from err
    print(f"\033[92mLoaded config from {config_path}\033[0m")
    return config

# Load trained model and latents
def load_model_and_latents(MODEL_PATH, LC_PATH, config, device):
    """Build a ``TriplanarDecoder`` from ``config`` and load its weights and latents.

    Args:
        MODEL_PATH: Checkpoint file; its ``'model'`` entry is loaded as the
            decoder state dict.
        LC_PATH: Latent-code checkpoint; ``['latent_codes']['weight']`` is read.
        config: Mapping holding the decoder hyper-parameters looked up below.
        device: Passed as ``map_location`` to ``torch.load`` and to ``model.to``.

    Returns:
        ``(model, latent_ckpt, latent_codes)`` where ``model`` is in eval mode
        on ``device``, ``latent_ckpt`` is the raw loaded latent checkpoint and
        ``latent_codes`` is the detached weight tensor moved to the CPU.
    """
    # Decoder keyword argument -> name of the config entry that supplies it.
    arg_to_config_key = {
        'latent_dim': 'latent_size',
        'n_objects': 'objects_per_decoder',
        'conv_hidden_dims': 'conv_hidden_dims',
        'conv_deep_image_size': 'conv_deep_image_size',
        'conv_norm': 'conv_norm',
        'conv_norm_type': 'conv_norm_type',
        'conv_start_with_mlp': 'conv_start_with_mlp',
        'sdf_latent_size': 'sdf_latent_size',
        'sdf_hidden_dims': 'sdf_hidden_dims',
        'sdf_weight_norm': 'weight_norm',
        'sdf_final_activation': 'final_activation',
        'sdf_activation': 'activation',
        'sdf_dropout_prob': 'dropout_prob',
        'sum_sdf_features': 'sum_conv_output_features',
        'conv_pred_sdf': 'conv_pred_sdf',
    }
    decoder_kwargs = {arg: config[key] for arg, key in arg_to_config_key.items()}

    # Decoder: instantiate, restore weights, move to the device, switch to eval.
    model = TriplanarDecoder(**decoder_kwargs)
    weights = torch.load(MODEL_PATH, map_location=device)
    model.load_state_dict(weights['model'])
    model.to(device)
    model.eval()

    # Latents: keep the full checkpoint and a detached CPU copy of the codes.
    latent_ckpt = torch.load(LC_PATH, map_location=device)
    latent_codes = latent_ckpt['latent_codes']['weight'].detach().cpu()
    return model, latent_ckpt, latent_codes

# Load landmarks file (.mrk.json)
def load_mrk_json(path):
    """Read landmark positions and labels from a ``.mrk.json`` markups file.

    Only the first entry of ``"markups"`` is read. Control points without a
    ``"position"`` entry are skipped together with their label, so the two
    returned sequences stay aligned.

    Args:
        path: Path to the ``.mrk.json`` file.

    Returns:
        ``(points, labels)``: ``points`` is a float32 array with one row per
        kept control point (shape ``(0, 3)`` when none are kept), ``labels``
        is a list with the matching ``"label"`` values (``None`` if absent).

    Raises:
        ValueError: If the file has no ``"markups"`` entry or it is empty.
    """
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    if "markups" not in data or len(data["markups"]) == 0:
        raise ValueError(f"No 'markups' found in {path}")
    control_points = data["markups"][0].get("controlPoints", [])
    kept = [cp for cp in control_points if cp.get("position", None) is not None]
    labels = [cp.get("label", None) for cp in kept]
    if not kept:
        # np.array([]) would be 1-D; keep an (N, 3) layout so callers can
        # still subtract a 3-vector centre from an empty landmark set.
        points = np.empty((0, 3), dtype=np.float32)
    else:
        points = np.array([cp["position"] for cp in kept], dtype=np.float32)
    return points, labels

# Convert vtk mesh to open3d
def vtk_to_o3d(vtk_file):
    """Read a mesh file with PyVista and rebuild it as an Open3D ``TriangleMesh``.

    Args:
        vtk_file: Path handed to ``pv.read``.

    Returns:
        An ``o3d.geometry.TriangleMesh`` with the file's vertices and faces,
        plus vertex normals when the file has a ``"Normals"`` point array.
    """
    source = pv.read(vtk_file)
    # The flat face array is viewed as rows of four values and the first
    # column is dropped; this layout only holds when every face is a triangle.
    triangles = source.faces.reshape((-1, 4))[:, 1:]

    converted = o3d.geometry.TriangleMesh()
    converted.vertices = o3d.utility.Vector3dVector(source.points)
    converted.triangles = o3d.utility.Vector3iVector(triangles)

    # Normals are optional in the input; copy them across only if present.
    normals = source.point_data.get("Normals")
    if normals is not None:
        converted.vertex_normals = o3d.utility.Vector3dVector(normals)
    return converted

# Snap landmarks to surface using nearest neighbor
def project_landmarks_to_surface(landmarks, target_mesh):
    """Replace each landmark by the nearest vertex of ``target_mesh``.

    The lookup is against ``target_mesh.points`` only, so the result is always
    an existing mesh vertex rather than an arbitrary point on a face.

    Args:
        landmarks: Query coordinates, one row per landmark.
        target_mesh: Object exposing a ``points`` attribute with the vertices.

    Returns:
        Array of the selected vertex coordinates, one row per landmark.
    """
    vertices = np.asarray(target_mesh.points)
    nearest_idx = cKDTree(vertices).query(landmarks)[1]
    return vertices[nearest_idx]

# Show original mesh and landmarks and generated mesh with interpolated landmarks in external, interactive window
def visualize_landmark_interpolation_external(original_mesh, mesh, original_pts, interp_pts, labels=None,
                                                animate=False, output_path=None, n_frames=120, off_screen=False, plot_title="Landmark Interpolation"):
    """Render two meshes with their landmark sets, interactively or as a rotating animation.

    The original mesh is drawn in light gray with red landmarks, the second
    mesh in blue with green landmarks, and an orange segment joins each
    original landmark to its interpolated counterpart.

    Args:
        original_mesh: Template mesh; ``pv.PolyData`` or an object providing
            ``extract_geometry()``.
        mesh: Second mesh; if it has a ``mesh`` attribute that attribute is
            used instead of the object itself.
        original_pts: Landmarks belonging to ``original_mesh``.
        interp_pts: Landmarks belonging to ``mesh``, paired by position with
            ``original_pts``.
        labels: Optional labels drawn at ``original_pts``.
        animate: If False, open an interactive window. If True, write a full
            360-degree azimuth rotation to ``output_path`` instead.
        output_path: Target file for the animation (``str`` or ``Path``). A
            ``.gif`` suffix writes a GIF; anything else is opened as a movie.
        n_frames: Number of frames in the rotation.
        off_screen: Forwarded to ``pv.Plotter``.
        plot_title: Text drawn in the window.

    Raises:
        ValueError: If ``animate`` is True and ``output_path`` is None.
    """
    # Fail fast, before any mesh processing or window creation.
    if animate:
        if output_path is None:
            raise ValueError("output_path must be provided when animate=True")
        # Accept plain strings as well as Path objects.
        output_path = Path(output_path)
    # Unwrap MeshWrapper if needed
    if hasattr(mesh, "mesh"):
        mesh = mesh.mesh
    # Ensure both meshes are PyVista PolyData
    mesh_pv = mesh if isinstance(mesh, pv.PolyData) else mesh.extract_geometry()
    original_mesh_pv = (original_mesh if isinstance(original_mesh, pv.PolyData) else original_mesh.extract_geometry())
    # Compute normals for smoother rendering
    mesh_pv = mesh_pv.compute_normals(cell_normals=False, point_normals=True)
    original_mesh_pv = original_mesh_pv.compute_normals(cell_normals=False, point_normals=True)
    # Landmarks as PolyData
    orig = pv.PolyData(original_pts)
    interp = pv.PolyData(interp_pts)
    # Lines connecting original → interpolated
    lines_polydata = pv.MultiBlock(
        [pv.Line(start, end) for start, end in zip(original_pts, interp_pts)]
    ).combine()
    # Open Interactive window
    plotter = pv.Plotter(window_size=(1024, 768), off_screen=off_screen, notebook=False,)
    plotter.add_text(plot_title, font_size=12)
    # Add meshes and landmarks
    plotter.add_mesh(original_mesh_pv, color="lightgray", opacity=0.4, smooth_shading=True, show_edges=False)
    plotter.add_mesh(mesh_pv, color="blue", opacity=0.4, smooth_shading=True, show_edges=False)
    plotter.add_mesh(orig, color="red", point_size=14, render_points_as_spheres=True)
    plotter.add_mesh(interp, color="green", point_size=14, render_points_as_spheres=True)
    plotter.add_mesh(lines_polydata, color="orange", line_width=2)
    # Optional labels
    if labels is not None:
        for pt, label in zip(original_pts, labels):
            plotter.add_point_labels([pt], [label], font_size=10, text_color="black")
    plotter.show_axes()
    plotter.show_grid()
    # Show in a standalone GUI window
    if not animate:
        plotter.show(interactive=True, auto_close=True)
        return
    # Save animated path (rotating meshes with LMs); always release the
    # plotter, even if opening the file or writing a frame fails.
    try:
        if output_path.suffix.lower() == ".gif":
            plotter.open_gif(str(output_path))
        else:
            plotter.open_movie(str(output_path), framerate=30, quality=9)
        # Zoom once up front so every frame uses the same framing.
        plotter.camera.zoom(1.2)
        for _ in range(n_frames):
            plotter.camera.azimuth += 360 / n_frames
            plotter.write_frame()
    finally:
        plotter.close()
    return

# Show set of landmarks on a mesh
def visualize_landmarks_on_mesh(original_mesh, original_pts):
    """Open an interactive window showing one mesh with its landmarks.

    Args:
        original_mesh: ``pv.PolyData`` or an object providing
            ``extract_geometry()``.
        original_pts: Landmark coordinates, drawn as red spheres.
    """
    # Work on PolyData; anything else is converted through extract_geometry().
    if isinstance(original_mesh, pv.PolyData):
        surface = original_mesh
    else:
        surface = original_mesh.extract_geometry()
    # Point normals give smooth shading; the input mesh is not modified.
    surface = surface.compute_normals(cell_normals=False, point_normals=True, inplace=False)
    landmark_cloud = pv.PolyData(original_pts)

    # Standalone (non-notebook) window.
    viewer = pv.Plotter(window_size=[1024, 768], notebook=False, off_screen=False)
    viewer.add_text("Landmark Interpolation", font_size=12)
    viewer.add_mesh(surface, color="lightgray", opacity=0.4, smooth_shading=True, show_edges=False)
    viewer.add_mesh(landmark_cloud, color="red", point_size=14, render_points_as_spheres=True)
    viewer.show_axes()
    viewer.show_grid()
    viewer.show(interactive=True, auto_close=True)

# Normalize mesh to unit sphere to match NSM backend scaling
def normalize_mesh_to_unit_sphere(mesh, reference_points=None):
    """Centre vertices on a centroid and divide by the largest centroid distance.

    Args:
        mesh: ``pv.PolyData`` or, otherwise, an array of vertex coordinates
            with one row per vertex.
        reference_points: Points used to compute centre and scale. Defaults
            to the mesh's own vertices.

    Returns:
        ``(normalized, center, scale)``. ``normalized`` is a copy of the input
        mesh with replaced points when a ``pv.PolyData`` was given, else the
        normalized vertex array. ``center`` is the mean of the reference
        points and ``scale`` the maximum distance of a reference point from it.
        No guard exists for ``scale == 0`` (all reference points identical).
    """
    is_polydata = isinstance(mesh, pv.PolyData)
    vertices = mesh.points.copy() if is_polydata else mesh.copy()
    if reference_points is None:
        reference_points = vertices

    center = np.mean(reference_points, axis=0)
    radii = np.linalg.norm(reference_points - center, axis=1)
    scale = np.max(radii)
    normalized_vertices = (vertices - center) / scale

    if not is_polydata:
        return normalized_vertices, center, scale
    # Return a fresh mesh so the caller's mesh keeps its original coordinates.
    normalized_mesh = mesh.copy()
    normalized_mesh.points = normalized_vertices
    return normalized_mesh, center, scale

# Convert open3d to pyvista mesh    
def o3d_to_pv(mesh_o3d):
    """Build a ``pv.PolyData`` from an Open3D triangle mesh.

    Args:
        mesh_o3d: Object exposing ``vertices`` and ``triangles``.

    Returns:
        ``pv.PolyData`` with the same vertices and triangular faces.
    """
    points = np.asarray(mesh_o3d.vertices)
    triangles = np.asarray(mesh_o3d.triangles)
    # Prefix every triangle with its vertex count (3) and flatten, giving the
    # [3, v0, v1, v2, 3, ...] layout passed to pv.PolyData.
    counts = np.full((triangles.shape[0], 1), 3)
    faces_flat = np.hstack((counts, triangles)).flatten()
    return pv.PolyData(points, faces_flat)

# Save landmark file for rendering in 3D slicer (.mrk.json)
def _as_position_list(pt):
    """Convert one landmark coordinate to a JSON-serialisable list."""
    if isinstance(pt, np.ndarray):
        return pt.tolist()
    if isinstance(pt, (list, tuple)):
        # NumPy scalars (e.g. np.float32) are not JSON serialisable; unwrap
        # them to native Python numbers and leave everything else untouched.
        return [c.item() if isinstance(c, np.generic) else c for c in pt]
    raise ValueError(f"Expected pt to be a list or numpy array, got {type(pt)}")


def save_mrk_json(points, labels, output_filename):
    """Write landmarks to a markups JSON file (``.mrk.json``) for 3D Slicer.

    Points and labels are paired positionally; control point ids are the
    1-based position as a string. Every control point is written as selected,
    locked, visible and with ``positionStatus`` ``"defined"``.

    Args:
        points: Iterable of coordinates; each item may be a NumPy array, a
            list or a tuple (NumPy scalar elements are converted).
        labels: Iterable of labels, one per point.
        output_filename: Destination path.

    Raises:
        ValueError: If a point is not an array, list or tuple.
    """
    control_points = []
    for i, (pt, label) in enumerate(zip(points, labels), start=1):
        cp = {
            "id": str(i),
            "label": label,
            "description": "",
            "associatedNodeID": "",
            "position": _as_position_list(pt),
            "orientation": [-1.0, -0.0, -0.0, -0.0, -1.0, -0.0, 0.0, 0.0, 1.0],
            "selected": True,
            "locked": True,  # Lock the position so it can't be moved in Slicer
            "lockedPosition": True,  # Optional: If you want explicit "lockedPosition" field
            "visibility": True,
            "positionStatus": "defined"}  # Ensure the position is defined and not undefined
        control_points.append(cp)
    markups_data = {
        "@schema": "https://raw.githubusercontent.com/slicer/slicer/master/Modules/Loadable/Markups/Resources/Schema/markups-schema-v1.0.3.json#",
        "markups": [{
                "type": "Fiducial",
                "coordinateSystem": "LPS",
                "coordinateUnits": "mm",
                "locked": False,  # Keep the overall markup locked status as False for editing purposes
                "fixedNumberOfControlPoints": False,
                "labelFormat": "%N-%d",
                "lastUsedControlPointNumber": len(control_points),
                "controlPoints": control_points,
                "measurements": [],
                "display": {
                    "visibility": False,
                    "opacity": 1.0,
                    "color": [0.4, 1.0, 1.0],
                    "selectedColor": [1.0, 0.5000076295109483, 0.5000076295109483],
                    "activeColor": [0.4, 1.0, 0.0],
                    "propertiesLabelVisibility": False,
                    "pointLabelsVisibility": False,
                    "textScale": 3.0,
                    "glyphType": "Sphere3D",
                    "glyphScale": 3.0,
                    "glyphSize": 5.0,
                    "useGlyphScale": True,
                    "sliceProjection": False,
                    "sliceProjectionUseFiducialColor": True,
                    "sliceProjectionOutlinedBehindSlicePlane": False,
                    "sliceProjectionColor": [1.0, 1.0, 1.0],
                    "sliceProjectionOpacity": 0.6,
                    "lineThickness": 0.2,
                    "lineColorFadingStart": 1.0,
                    "lineColorFadingEnd": 10.0,
                    "lineColorFadingSaturation": 1.0,
                    "lineColorFadingHueOffset": 0.0,
                    "handlesInteractive": False,
                    "translationHandleVisibility": True,
                    "rotationHandleVisibility": True,
                    "scaleHandleVisibility": False,
                    "interactionHandleScale": 3.0,
                    "snapMode": "toVisibleSurface"}}]}
    # Save to JSON file
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(markups_data, f, indent=4)
    print(f"Saved to {output_filename}")

# Interpolate landmark points from latent 1 to latent 2
def interpolate_points(
    model, latent1, latent2, n_steps=100, points1=None, surface_idx=0, verbose=False, spherical=True, smooth_type="taubin"):
    """Point-mode wrapper around ``interpolate_common`` (``is_mesh=False``).

    This definition replaces the ``interpolate_points`` name imported from
    ``NSM.mesh.interpolate`` earlier in the file. All arguments are forwarded
    unchanged; their exact semantics are defined by ``interpolate_common``.

    Args:
        model: Decoder passed through to ``interpolate_common``.
        latent1: Starting latent.
        latent2: Target latent.
        n_steps: Number of interpolation steps.
        points1: Points associated with ``latent1`` to carry over.
        surface_idx: Forwarded as-is.
        verbose: Forwarded as-is.
        spherical: Forwarded as-is (presumably selects spherical
            interpolation between the latents; confirm in ``interpolate_common``).
        smooth_type: Smoothing mode forwarded to ``interpolate_common``.

    Returns:
        Whatever ``interpolate_common`` returns.
    """
    # Forward the caller's smooth_type rather than a fixed "taubin" literal,
    # so a non-default value actually takes effect.
    return interpolate_common(
        model, latent1, latent2, n_steps, points1, surface_idx, verbose, spherical, is_mesh=False, smooth_type=smooth_type)

# Get indices for interpolation by filename or randomly
def get_indices(train_paths, random_sample=True, n_meshes=5, string_to_match=None):
    """Choose which specimens to process, either at random or by filename.

    Random sampling is used only when ``random_sample`` is true and no
    ``string_to_match`` is given. Whenever ``string_to_match`` is supplied,
    filename matching is used regardless of ``random_sample``.

    In both modes indices are drawn without replacement, so no specimen is
    returned twice, and fewer than ``n_meshes`` indices are returned (with a
    printed warning) when not enough candidates exist.

    Args:
        train_paths: Sequence of file names/paths; returned indices refer to
            positions in this sequence.
        random_sample: Sample uniformly from all entries.
        n_meshes: Number of indices requested.
        string_to_match: Case-insensitive substring that a path must contain
            to be a candidate.

    Returns:
        1-D integer NumPy array of indices into ``train_paths``.

    Raises:
        ValueError: If ``random_sample`` is false and ``string_to_match`` is
            None.
    """
    # If randomly sampling
    if random_sample and string_to_match is None:
        print(f"Randomly sampling {n_meshes} indices from latent codes")
        n_available = len(train_paths)
        if n_available < n_meshes:
            print(f"Warning: Only {n_available} meshes available, sampling all.")
        # A shuffled range, truncated: unique indices, and safe for an empty
        # or too-short list (the slice simply returns what is there).
        return np.random.permutation(n_available)[:n_meshes]
    # If matching by string
    if string_to_match is None:
        raise ValueError("Target string must be provided when not sampling randomly.")
    print(f"Finding {n_meshes} latent code indices based on filenames containing: '{string_to_match}'")
    needle = string_to_match.lower()
    matched_indices = [i for i, s in enumerate(train_paths) if needle in s.lower()]
    if len(matched_indices) < n_meshes:
        print(f"Warning: Only {len(matched_indices)} matches found, sampling all.")
        # Explicit integer dtype: an empty list would otherwise become a
        # float array, which cannot be used for indexing.
        return np.array(matched_indices, dtype=int)
    return np.random.choice(matched_indices, n_meshes, replace=False)

In [None]:
# Load the run configuration, the trained decoder and its latent codes, and
# define the parameters later handed to create_mesh.

# Configuration: compute device (defaults to "cuda:0") and the mesh path list
config = load_config()
device = config.get("device", "cuda:0")
train_paths = config['list_mesh_paths']
# Alternative source for the path list: train_paths = config['test_paths']
all_vtk_files = list(map(os.path.basename, train_paths))

# Decoder in eval mode plus the latent checkpoint and CPU latent codes
model, latent_ckpt, latent_codes = load_model_and_latents(MODEL_PATH, LC_PATH, config, device)

# Reconstruction grid: a cube spanning [-recon_grid_origin, recon_grid_origin]
# on each axis, sampled with n_pts_per_axis points per axis
recon_grid_origin = 1.0
n_pts_per_axis = 128
voxel_origin = (-recon_grid_origin,) * 3
voxel_size = (recon_grid_origin * 2) / (n_pts_per_axis - 1)

# Neutral placement: zero offset, unit scale and an identity ICP transform
offset = np.zeros(3)
scale = 1.0
icp_transform = NumpyTransform(np.eye(4))
objects = 1

In [None]:
# Pick the training specimen whose latent code is closest (by cosine distance)
# to the element-wise median latent, and use its mesh and landmarks as template

# Element-wise median over all latent codes
latent_codes = latent_codes.to(device)
median_latent = torch.median(latent_codes, dim=0).values  # shape: (latent_dim,)

# Unit-length versions of every latent and of the median, so that the matrix
# product below yields cosine similarities
latent_norm = F.normalize(latent_codes, p=2, dim=1)
median_latent_norm = F.normalize(median_latent.unsqueeze(0), p=2, dim=1)
cosine_similarities = torch.mm(latent_norm, median_latent_norm.T).squeeze()
cosine_distances = 1 - cosine_similarities

# Smallest cosine distance -> template specimen and its files on disk
med_idx = torch.argmin(cosine_distances).item()
print(f"Closest specimen to median latent: {all_vtk_files[med_idx]}")
med_fn = f"../vertebrae_meshes/{all_vtk_files[med_idx]}"
med_lm_fn = f"../alignedLMs/{os.path.splitext(all_vtk_files[med_idx])[0]}.mrk.json"
print("Associated landmark file: ", med_lm_fn, "\033[0m\n")

# Record which specimen was chosen in the output directory
interp_dir = Path("interpedLMs")
interp_dir.mkdir(parents=True, exist_ok=True)
output_txt = interp_dir / "closest_specimen_to_median.txt"
with open(output_txt, "w") as f:
    f.writelines([f"Mesh: {med_fn}\n", f"Landmarks: {med_lm_fn}\n"])
print(f"Saved closest specimen to median info to :{output_txt}")

# Load the template mesh (VTK -> Open3D -> PyVista) and its landmarks, apply
# the same centre/scale normalization to both, then snap the landmarks onto
# vertices of the normalized mesh
med_mesh = o3d_to_pv(vtk_to_o3d(med_fn))
med_lms, med_labels = load_mrk_json(med_lm_fn)
median_mesh, mesh_center, mesh_scale = normalize_mesh_to_unit_sphere(med_mesh)
normalized_lms = (med_lms - mesh_center) / mesh_scale
median_lms = project_landmarks_to_surface(normalized_lms, getattr(median_mesh, "mesh", median_mesh))

# Visualize landmarks on mesh in external viewer
visualize_landmarks_on_mesh(median_mesh, median_lms)

In [None]:
# Interpolate points from median specimen to random specimens

# Save resulting landmarks and meshes?
save_results = False # TO DO: Set to True or False

# Get indices of latents to interpolate
indices = get_indices(train_paths=all_vtk_files, random_sample=True, n_meshes=5, string_to_match=None)

# Results, one entry per processed index (read by the visualization cells)
interpolated_points = []
meshes = []
mesh_names = []

# The source latent is the median for every target, so convert it to a
# float64 NumPy vector once instead of on every iteration
latent1 = median_latent.cpu().numpy().astype(np.float64)

# Loop through the selected latent codes for interpolation
for i, idx in enumerate(indices):
    # Set out filename for landmarks
    vtk_path = Path(all_vtk_files[idx])
    stem = vtk_path.stem

    print(f"\n\033[92m{i}: {vtk_path.name}\033[0m")
    mesh_names.append((idx, vtk_path.name))

    out_lm_fn = interp_dir / f"interped_{stem}.mrk.json"
    out_mesh_fn = interp_dir / f"interped_{stem}.vtk"
    # Load latent codes by index
    latent = latent_codes[idx]
    # Create mesh from latent
    mesh_out = create_mesh(
            decoder=model, latent_vector=latent, n_pts_per_axis=n_pts_per_axis,
            voxel_origin=voxel_origin, voxel_size=voxel_size, path_original_mesh=None,
            offset=offset, scale=scale, icp_transform=icp_transform,
            objects=objects, verbose=False, device=device
        )
    # create_mesh may return a list; keep the first mesh in that case
    mesh_out = mesh_out[0] if isinstance(mesh_out, list) else mesh_out
    mesh_out.resample_surface(clusters=20_000)

    # Target latent as a float64 NumPy vector (dtype conversion only; the
    # values are not rescaled)
    latent2 = latent.cpu().numpy().astype(np.float64)

    # --- Interpolate landmarks from the median latent to this latent ---
    interp_pts = interpolate_points(
        model=model,
        latent1=latent1,
        latent2=latent2,
        points1=median_lms,
        n_steps=100,
        smooth_type="taubin"
    )

    # Re-centre and re-scale the generated mesh on its own vertices and apply
    # the same centre/scale to the interpolated landmarks.
    # NOTE(review): assumes interp_pts is an array of coordinates that
    # broadcasts against mesh_center - confirm against interpolate_common.
    normalized_mesh, mesh_center, mesh_scale = normalize_mesh_to_unit_sphere(mesh_out)
    normalized_lms = (interp_pts - mesh_center) / mesh_scale

    # Append generated mesh and interpolated points to list
    meshes.append(normalized_mesh)
    interpolated_points.append(normalized_lms)

    # Optional: Save interpolated landmarks and generated meshes, and copy the
    # specimen's original landmark file next to them for comparison
    if save_results:
        print("Saving results for interpolated landmarks and meshes...")
        save_mrk_json(normalized_lms, med_labels, out_lm_fn)
        orig_lm_fn = Path("..") / "alignedLMs" / f"{stem}.mrk.json"
        dest_lm_fn = interp_dir / orig_lm_fn.name
        shutil.copy(orig_lm_fn, dest_lm_fn)
        normalized_mesh.save(out_mesh_fn)
        print(f"Saved mesh → {out_mesh_fn.name}")
        print(f"Copied original landmarks → {dest_lm_fn.name}")

In [None]:
# Open one interpolation result (template vs. generated mesh, with both
# landmark sets) in an external interactive window
pv.set_jupyter_backend(None)  # disable notebook mode completely

# Position in the results lists (not a latent-code index); change as needed
idx_to_show = 1

# mesh_names holds (latent index, file name) tuples; the title uses the name
mesh_name = mesh_names[idx_to_show][1]
plot_title = f"Median specimen: {all_vtk_files[med_idx]}\n vs \nInterpolated specimen: {mesh_name}"

# Show results in external viewer
visualize_landmark_interpolation_external(
    original_mesh=median_mesh,
    mesh=meshes[idx_to_show],
    original_pts=median_lms,
    interp_pts=interpolated_points[idx_to_show],
    labels=med_labels,
    plot_title=plot_title,
)

In [None]:
# Render one interpolation result as a rotating GIF under interpedLMs/gifs

# Position in the results lists (not a latent-code index); change as needed
idx_to_show = 0

# mesh_names holds (latent index, file name) tuples; the title uses the name
mesh_name = mesh_names[idx_to_show][1]
plot_title = f"Median specimen: {all_vtk_files[med_idx]}\n vs \nInterpolated specimen: {mesh_name}"

# Output location; mesh_name still carries its original file extension, so
# the file name ends in "<ext>.gif"
gif_dir = interp_dir / "gifs"
gif_dir.mkdir(parents=True, exist_ok=True)
output_fname = f"{idx_to_show}_{mesh_name}.gif"
output_fpath = gif_dir / output_fname

# Off-screen rendering: frames are written to the file, no window is shown
visualize_landmark_interpolation_external(
    original_mesh=median_mesh,
    mesh=meshes[idx_to_show],
    original_pts=median_lms,
    interp_pts=interpolated_points[idx_to_show],
    labels=med_labels,
    animate=True,
    output_path=output_fpath,
    off_screen=True,
    n_frames=240,
    plot_title=plot_title,
)