# Generating Explicit Priors

In [2]:
import json
import sys, glob, numpy as np, open3d as o3d, mesh_to_sdf as mts
from skimage import measure
from tqdm import tqdm
from stl import mesh as npmesh

sys.path.append("../../")

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
with open('../data/splits.json', 'r') as file:
    splits = json.load(file)

In [4]:
train_pollen = splits["train"]

In [None]:
import numpy as np
import trimesh
import os

def create_volumetric_mean_shape(base_dir,
                                 filenames,
                                 output_path,
                                 pitch=0.1,
                                 align_iterations=5):
    """
    Compute a mean shape from a list of meshes via volumetric (SDF) averaging.
    """

    # 1. Load meshes ----------------------------------------------------------
    if not filenames:
        print("The provided filename list is empty.")
        return

    print(f"Attempting to load {len(filenames)} meshes...")
    meshes = []
    for fname in filenames:
        full = os.path.join(base_dir, fname)
        if not os.path.exists(full):
            print(f"Warning: file not found → {full}")
            continue
        try:
            meshes.append(trimesh.load_mesh(full))
        except Exception as exc:
            print(f"Warning: failed to load {full} — {exc}")
    if not meshes:
        print("No valid meshes could be loaded. Aborting.")
        return
    print(f"Successfully loaded {len(meshes)} meshes.")

    # 2. Rigidly align meshes (ICP) ------------------------------------------
    print("Aligning meshes...")
    ref            = meshes[0]
    aligned        = [ref]
    for i, mesh in enumerate(meshes[1:], start=2):
        m = mesh.copy()
        for _ in range(align_iterations):
            T, _, _ = trimesh.registration.icp(m.vertices, ref, max_iterations=20)
            m.apply_transform(T)
        aligned.append(m)
        print(f"  aligned {i}/{len(meshes)}", end='\r')
    print("\nAlignment complete.")

    # 3. Build one dense voxel grid covering *all* aligned meshes ------------
    print("Defining a common grid...")
    # global (xmin, ymin, zmin) / (xmax, ymax, zmax)
    mins = np.min([m.bounds[0] for m in aligned], axis=0)
    maxs = np.max([m.bounds[1] for m in aligned], axis=0)
    overall_bounds = np.vstack((mins, maxs))              # shape (2, 3)

    # number of voxels along each axis
    dims = np.ceil((maxs - mins) / pitch).astype(int) + 1  # (nx, ny, nz)
    nx, ny, nz = dims

    # generate every voxel centre in a dense grid
    ix, iy, iz = np.indices(dims)
    grid_points = np.column_stack((
        mins[0] + ix.ravel() * pitch,
        mins[1] + iy.ravel() * pitch,
        mins[2] + iz.ravel() * pitch,
    ))
    print(f"Grid shape {tuple(dims)}, total voxels {len(grid_points):,}")

    # 4. Sample SDF on the common grid & sum ---------------------------------
    print(f"Sampling SDF for {len(aligned)} meshes (pitch = {pitch})...")
    summed_sdf = np.zeros(len(grid_points), dtype=np.float64)
    for i, mesh in enumerate(aligned, start=1):
        sdf = trimesh.proximity.signed_distance(mesh, grid_points)
        summed_sdf += sdf
        print(f"  sampled {i}/{len(aligned)}", end='\r')
    print("\nSampling complete.")

    # 5. Average SDF ---------------------------------------------------------
    averaged_sdf = summed_sdf / len(aligned)

    # 6. Marching-Cubes extraction ------------------------------------------
    print("Extracting mean shape via Marching Cubes...")
    matrix = averaged_sdf.reshape(tuple(dims))          # (nx, ny, nz)

    # ⬇️  no 'origin' kwarg here
    mean_mesh = trimesh.voxel.ops.matrix_to_marching_cubes(
        matrix=matrix,
        pitch=pitch
    )

    # shift the mesh from (0, 0, 0) to the true grid origin
    mean_mesh.apply_translation(mins)
    
    # write mesh to file
    print(f"Writing mean mesh to {output_path}...")
    mean_mesh.export(output_path)
    print("Mean shape creation complete.")
    

# ---------------------------------------------------------------------------
# Example call
mesh_directory = '../data/processed/meshes'
output_file    = './volumetric_mean_shape.stl'
voxel_pitch    = 0.1

create_volumetric_mean_shape(
    mesh_directory,
    train_pollen[:2],      # your list of mesh filenames
    output_file,
    pitch=voxel_pitch
)

Attempting to load 2 meshes...
Successfully loaded 2 meshes.
Aligning meshes...
  aligned 2/2
Alignment complete.
Defining a common grid...
Grid shape (np.int64(264), np.int64(298), np.int64(278)), total voxels 21,870,816
Sampling SDF for 2 meshes (pitch = 0.1)...


In [6]:
import os
import numpy as np
import trimesh


def mean_mesh_128(base_dir, filenames, output_path,
                  grid_size: int = 128,
                  align_iterations: int = 5,
                  verbose: bool = True):
    """
    Build a “mean” mesh from many meshes by SDF averaging on a fixed
    128³ voxel grid.

    Parameters
    ----------
    base_dir        : folder that contains the mesh files
    filenames       : list[str] – the mesh files (anything `trimesh` loads)
    output_path     : where to write the mean STL/PLY/OBJ/…
    grid_size       : number of voxels per edge (default 128)
    align_iterations: ICP iterations per mesh (0 ➜ skip ICP)
    """

    if not filenames:
        print("Nothing to do: empty filename list.")
        return

    # ------------------------------------------------------------------ 1
    # load meshes
    # ------------------------------------------------------------------
    meshes = []
    for name in filenames:
        path = os.path.join(base_dir, name)
        try:
            meshes.append(trimesh.load_mesh(path, force='mesh'))
            if verbose:
                print(f"loaded  {path}")
        except Exception as e:
            print(f"✗ {path}  ({e}) – skipped")

    if not meshes:
        print("No meshes could be loaded. Aborting.")
        return

    # ------------------------------------------------------------------ 2
    # rigid ICP alignment to the first mesh  (optional)
    # ------------------------------------------------------------------
    if align_iterations > 0:
        if verbose:
            print("Aligning meshes …")
        ref = meshes[0]
        aligned = [ref]
        for i, m in enumerate(meshes[1:], start=2):
            m = m.copy()
            for _ in range(align_iterations):
                T, _, _ = trimesh.registration.icp(
                    m.vertices, ref, max_iterations=20
                )
                m.apply_transform(T)
            aligned.append(m)
            if verbose:
                print(f"  • {i}/{len(meshes)}", end="\r")
        meshes = aligned
        if verbose:
            print("\nAlignment done.")

    # ------------------------------------------------------------------ 3
    # put *all* meshes into a single bounding box centred at the origin
    # ------------------------------------------------------------------
    all_bounds = np.vstack([m.bounds for m in meshes])
    mins       = all_bounds.min(axis=0)
    maxs       = all_bounds.max(axis=0)
    centre     = (mins + maxs) / 2.0
    extent     = (maxs - mins).max()           # longest edge
    pitch      = extent / (grid_size - 1)      # voxel size

    for m in meshes:
        m.apply_translation(-centre)           # now centred at (0,0,0)

    grid_min = -extent / 2.0                   # coordinate of voxel (0,0,0)

    if verbose:
        print(f"Grid   : {grid_size}³  pitch={pitch:.5g}")
        print(f"bbox   : {grid_min:.5g} … {grid_min + extent:.5g} (along each axis)")

    # ------------------------------------------------------------------ 4
    # build the dense 3-D grid of sampling points
    # ------------------------------------------------------------------
    nx = ny = nz = grid_size
    ix, iy, iz = np.indices((nx, ny, nz))
    grid_points = np.column_stack((
        grid_min + ix.ravel() * pitch,
        grid_min + iy.ravel() * pitch,
        grid_min + iz.ravel() * pitch,
    ))

    # ------------------------------------------------------------------ 5
    # sample & accumulate SDFs
    # ------------------------------------------------------------------
    if verbose:
        print("Sampling signed-distance fields …")
    sdf_sum = np.zeros(len(grid_points), dtype=np.float64)
    for i, m in enumerate(meshes, start=1):
        sdf_sum += trimesh.proximity.signed_distance(m, grid_points)
        if verbose:
            print(f"  • {i}/{len(meshes)}", end="\r")
    if verbose:
        print("\nSDF sampling done.")

    mean_sdf = sdf_sum / len(meshes)
    volume   = mean_sdf.reshape((nx, ny, nz))

    # ------------------------------------------------------------------ 6
    # Marching Cubes
    # ------------------------------------------------------------------
    if verbose:
        print("Running Marching Cubes …")
    try:
        mean_mesh = trimesh.voxel.ops.matrix_to_marching_cubes(
            matrix=volume,
            pitch=pitch,
            origin=[grid_min]*3               # newest Trimesh versions
        )
    except TypeError:                         # older Trimesh: no 'origin'
        mean_mesh = trimesh.voxel.ops.matrix_to_marching_cubes(
            matrix=volume,
            pitch=pitch
        )
        mean_mesh.apply_translation([grid_min]*3)

    # ------------------------------------------------------------------ 7
    # export
    # ------------------------------------------------------------------
    if mean_mesh.is_empty:
        print("⚠️  Mean mesh is empty – try a finer grid or better alignment.")
    else:
        mean_mesh.export(output_path)
        print(f"✅  Mean mesh written to  {output_path}")


# ---------------------------------------------------------------------- #
# example call
# ---------------------------------------------------------------------- #
mesh_dir   = "../data/processed/meshes"
out_file   = "./mean_shape_128.stl"
voxel_grid = 32                  # keep default
icp_iters  = 5

mean_mesh_128(
    mesh_dir,
    train_pollen[:10],            # list of mesh filenames
    out_file,
    grid_size=voxel_grid,
    align_iterations=icp_iters
)


loaded  ../data/processed/meshes\17817_Creeping_buttercup_Ranunculus_repens_pollen_grain.stl
loaded  ../data/processed/meshes\17788_Gorse_Ulex_europaeus_pollen_grain.stl
loaded  ../data/processed/meshes\21600_Common_haircap_Polytrichum_commune_spore.stl
loaded  ../data/processed/meshes\20941_Lingonberry_Vaccinium_vitis-idaea_pollen_grain.stl
loaded  ../data/processed/meshes\17883_Evening_primrose_Oenothera_fruticosa_pollen_grain.stl
loaded  ../data/processed/meshes\17779_Strawberry_Fragaria_ananassa_pollen_grain.stl
loaded  ../data/processed/meshes\17886_Common_wheat_Triticum_aestivan_pollen_grain.stl
loaded  ../data/processed/meshes\17811_White_goosefoot_Chenopodium_album_pollen_grain.stl
loaded  ../data/processed/meshes\17843_Shining_pondweed_Potamogeton_lucens_pollen_grain.stl
loaded  ../data/processed/meshes\21103_Round-headed_Rampion_Phyteuma_tenerum_pollen_grain.stl
Aligning meshes …
  • 10/10
Alignment done.
Grid   : 32³  pitch=0.98123
bbox   : -15.209 … 15.209 (along each axis)

MemoryError: Unable to allocate 4.38 GiB for an array with shape (65366168, 3, 3) and data type float64