# Process Surfaces with Correspondence for PCA

This notebook processes aligned surface files with correspondence data to prepare inputs for PCA analysis.

**Workflow:**
1. Read the correspondence meshes (`*.vtp` files) from the `surfaces_aligned_correspond/` subdirectory
2. Read matching VTP files from `surfaces/` subdirectory
3. For each point in the correspondence mesh, extract its `vtkOriginalPointIds` field
4. Use those IDs as indices to look up point positions in the original surface mesh
5. Average the surface mesh point positions for each correspondence point
6. Replace the correspondence point position with that average
7. Save results to `pca_inputs/` subdirectory

**Purpose:**
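The correspondence algorithm creates a dense mesh where each point has a list of original point IDs from the surface mesh. This notebook averages the original surface positions referenced by those IDs, creating a smoothed mesh that better represents the underlying surface geometry for PCA analysis.

**Note:** The processing function defined in this notebook does not read `vtkOriginalPointIds` itself. It deformably registers each correspondence mesh to its matching original surface with `RegisterModelsDistanceMaps` and saves the registered mesh. If ID-based averaging is still required, confirm that it is performed elsewhere.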


In [None]:
from pathlib import Path

import itk
import pyvista as pv
from physiomotion4d.register_model_to_model_masks import RegisterModelsDistanceMaps

# Enable interactive plotting by selecting PyVista's 'trame' Jupyter backend
pv.set_jupyter_backend('trame')


## 1. Setup Directories and Find Files


In [None]:
# Input and output locations (relative to the current working directory)
correspond_dir = Path('surfaces_aligned_correspond')
surfaces_dir = Path('surfaces')
output_dir = Path('pca_inputs')

# Collect the correspondence meshes (*.vtp only) in sorted order
correspond_files = sorted(correspond_dir.glob('*.vtp'))

# Make sure the output directory exists before any results are written
output_dir.mkdir(exist_ok=True)

# Report which files were found
print(f"Found {len(correspond_files)} files in {correspond_dir}/")
for f in correspond_files:
    print(f"  {f.name}")


## 2. Define Processing Function

This function:
1. Loads the correspondence mesh and its matching original surface mesh (`<base_name>.vtp`, where `<base_name>` is the correspondence file name with `_correspond` removed); if the surface file is missing, it prints an error and returns `None`
2. Deformably registers the correspondence mesh (moving) to the original surface mesh (fixed) with `RegisterModelsDistanceMaps`, using the supplied reference image
3. Saves the registered mesh to the output directory as `<base_name>.vtk`
4. Returns the registered mesh


In [None]:
def process_correspondence_mesh(
    correspond_file,
    surfaces_dir,
    output_dir,
    ref_image,
):
    """
    Deformably register a correspondence mesh to its original surface and save it.

    The matching surface is looked up as ``<surfaces_dir>/<base_name>.vtp``,
    where ``base_name`` is the stem of ``correspond_file`` with every
    ``'_correspond'`` substring removed. The correspondence mesh is used as the
    moving mesh and the original surface as the fixed mesh of a
    ``RegisterModelsDistanceMaps`` registration run in ``"deformable"`` mode.

    Args:
        correspond_file: Path (``str`` or ``Path``) of the correspondence mesh.
        surfaces_dir: Directory (``str`` or ``Path``) holding the original
            ``.vtp`` surfaces.
        output_dir: Directory (``str`` or ``Path``) that receives
            ``<base_name>.vtk``. It is not created here.
        ref_image: Forwarded unchanged to ``RegisterModelsDistanceMaps`` as
            ``reference_image``.

    Returns:
        The mesh stored under ``"moving_mesh"`` in the registration result, or
        ``None`` when no matching surface file exists.
    """
    # Accept plain strings as well as Path objects.
    correspond_file = Path(correspond_file)
    surfaces_dir = Path(surfaces_dir)
    output_dir = Path(output_dir)

    # Locate the matching surface first so that a missing file is reported
    # before any mesh is read from disk.
    base_name = correspond_file.stem.replace('_correspond', '')
    surface_file = surfaces_dir / f"{base_name}.vtp"
    if not surface_file.exists():
        print(f"  ERROR: Could not find matching surface file for {base_name}")
        return None

    correspond_mesh = pv.read(correspond_file)
    surface_mesh = pv.read(surface_file)

    registrar = RegisterModelsDistanceMaps(
        moving_mesh=correspond_mesh,
        fixed_mesh=surface_mesh,
        reference_image=ref_image,
    )

    result = registrar.register(mode="deformable")

    # NOTE(review): presumably the moving mesh after registration -- confirm
    # against the RegisterModelsDistanceMaps return contract.
    registered_mesh = result["moving_mesh"]

    # Save processed mesh
    output_file = output_dir / f"{base_name}.vtk"
    registered_mesh.save(output_file)
    print(f"  Saved to: {output_file}")

    return registered_mesh


## 3. Process All Correspondence Files


In [None]:
# Outcome of the batch run: registered meshes keyed by file stem, plus the
# names of the files that could not be processed
processed_meshes = {}
failed_files = []

# Build the reference image from the template mesh's bounding box, padded by
# one third of the box size on every side and sampled on a 300^3 grid
template_mesh = pv.read("average_surface.vtp")
bounds = template_mesh.bounds
xmin, xmax, ymin, ymax, zmin, zmax = bounds
img_size = [300, 300, 300]
axis_ranges = [(xmin, xmax), (ymin, ymax), (zmin, zmax)]
origin = [lo - (hi - lo) / 3 for lo, hi in axis_ranges]
extent = [hi + (hi - lo) / 3 for lo, hi in axis_ranges]
spacing = [
    (far - near) / count
    for near, far, count in zip(origin, extent, img_size)
]

region = itk.ImageRegion[3]()
region.SetSize(img_size)

ref_image = itk.Image[itk.F, 3].New()
ref_image.SetRegions(region)
ref_image.SetOrigin(origin)
ref_image.SetSpacing(spacing)
ref_image.Allocate()

if correspond_files:
    for correspond_file in correspond_files:
        try:
            result = process_correspondence_mesh(
                correspond_file,
                surfaces_dir,
                output_dir,
                ref_image,
            )
        except Exception as e:
            # Keep going with the remaining files; just record the failure
            print(f"\n  ERROR processing {correspond_file.name}: {str(e)}")
            failed_files.append(correspond_file.name)
            continue

        # A None result means the function reported the problem itself
        if result is None:
            failed_files.append(correspond_file.name)
        else:
            processed_meshes[correspond_file.stem] = result

    # Summary of the batch run
    print(f"\n{'='*70}")
    print("Processing Complete!")
    print(f"{'='*70}")
    print(f"  Successfully processed: {len(processed_meshes)} files")
    print(f"  Failed: {len(failed_files)} files")
    if failed_files:
        print(f"  Failed files: {', '.join(failed_files)}")
else:
    print("WARNING: No files found in surfaces_aligned_correspond directory!")
    print("Please run the correspondence algorithm (e.g., SlicerSALT) first.")
