
# Functional → Anatomy Registration Prototype

This notebook exercises the functional → anatomy registration helper.
Run inside the `fireantsGH` environment.

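The expected scale (and the scale search range built from it) is derived automatically from the
pixel sizes in the preprocessing metadata; the remaining registration parameters are constants in
the first code cell, which you can adjust to override the defaults.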


In [None]:

from pathlib import Path
import json

import numpy as np
import pandas as pd
import tifffile as tiff

from social_imaging_scripts.registration.functional_to_anatomy_ants import register_planes_pass1

# Configure paths
# ANIMAL_ID and PIPELINE_ROOT are the only values to edit when pointing the
# notebook at another dataset; every path below is derived from them. The
# identifier used to be repeated eight times across the four path literals.
ANIMAL_ID = "L395_f11"
PIPELINE_ROOT = Path("/mnt/f/johannes/pipelineOut")
PREPROCESSING_DIR = PIPELINE_ROOT / ANIMAL_ID / "02_reg" / "00_preprocessing"
ANATOMY_DIR = PREPROCESSING_DIR / "2p_anatomy"
FUNCTIONAL_DIR = PREPROCESSING_DIR / "2p_functional"

ANATOMY_STACK_PATH = ANATOMY_DIR / f"{ANIMAL_ID}_anatomy_stack.tif"
ANATOMY_METADATA_PATH = ANATOMY_DIR / f"{ANIMAL_ID}_anatomy_metadata.json"
FUNCTIONAL_PROJECTION_PATH = (
    FUNCTIONAL_DIR / "02_motionCorrected" / "projections" / f"{ANIMAL_ID}_avg_projections.tif"
)
FUNCTIONAL_METADATA_PATH = (
    FUNCTIONAL_DIR / "01_individualPlanes" / f"{ANIMAL_ID}_preprocessing_metadata.json"
)
# Positions along the first axis of the functional projection stack to register.
PLANE_INDICES = [0, 1, 2, 3, 4]


def _read_pixel_size(metadata_path: Path) -> tuple[float, float]:
    payload = json.loads(metadata_path.read_text())
    pixels = payload.get("pixel_size_xy_um")
    if not isinstance(pixels, (list, tuple)) or len(pixels) != 2:
        raise ValueError(f"Missing pixel_size_xy_um in {metadata_path}")
    return float(pixels[0]), float(pixels[1])


# --- Registration parameters ---------------------------------------------------
functional_pixel_size_um = _read_pixel_size(FUNCTIONAL_METADATA_PATH)
anatomy_pixel_size_um = _read_pixel_size(ANATOMY_METADATA_PATH)
# Expected functional -> anatomy scale: mean of the two per-axis pixel-size ratios.
expected_scale = float(
    (functional_pixel_size_um[0] / anatomy_pixel_size_um[0] + functional_pixel_size_um[1] / anatomy_pixel_size_um[1])
    / 2.0
)
# Fractional half-width of the scale search around expected_scale; 0.0 pins the
# search range to the expected scale itself.
scale_window = 0.0
REG_SCALE_RANGE = (
    # NOTE(review): the 0.1 floor would exceed the upper bound if expected_scale
    # were ever below 0.1 -- confirm that cannot happen for these datasets.
    max(0.1, expected_scale * (1.0 - scale_window)),
    expected_scale * (1.0 + scale_window),
)
REG_N_SCALES = 1
REG_Z_STRIDE_COARSE = 8
REG_Z_REFINE_RADIUS = 10
REG_GAUSSIAN_SIGMA = 1.0
REG_EARLY_STOP_SCORE = 0.995  # Set to None to disable early exit
REG_PROGRESS = True            # True passes `print` as the progress callback; False passes None

print("Functional pixel size (µm):", functional_pixel_size_um)
print("Anatomy pixel size (µm):", anatomy_pixel_size_um)
print(
    f"Expected scale ≈ {expected_scale:.3f}; search range "
    f"[{REG_SCALE_RANGE[0]:.3f}, {REG_SCALE_RANGE[1]:.3f}]"
)

# --- Load image data -----------------------------------------------------------
print("Loading anatomy stack…", ANATOMY_STACK_PATH)
anatomy_stack = tiff.imread(str(ANATOMY_STACK_PATH)).astype(np.float32)
print("Anatomy stack shape:", anatomy_stack.shape)

print("Loading functional projections…", FUNCTIONAL_PROJECTION_PATH)
functional_stack = tiff.imread(str(FUNCTIONAL_PROJECTION_PATH)).astype(np.float32)
if functional_stack.ndim == 2:
    # A projection file holding a single plane presumably loads as a 2-D image
    # (TODO confirm against tifffile); promote it to a one-plane stack so the
    # indexing below selects planes rather than image rows.
    functional_stack = functional_stack[np.newaxis, ...]

# Fail early with a readable message instead of a bare IndexError mid-comprehension.
# Negative indices stay valid, as with ordinary sequence indexing.
n_functional_planes = functional_stack.shape[0]
invalid_plane_indices = [
    idx for idx in PLANE_INDICES if not -n_functional_planes <= idx < n_functional_planes
]
if invalid_plane_indices:
    raise IndexError(
        f"PLANE_INDICES {invalid_plane_indices} are out of range for a functional stack "
        f"with {n_functional_planes} plane(s) ({FUNCTIONAL_PROJECTION_PATH})"
    )

functional_planes = [functional_stack[idx] for idx in PLANE_INDICES]
print(f"Loaded {len(functional_planes)} functional planes with shape {functional_planes[0].shape}")


In [None]:

# Run the pass-1 registration with the parameters from the configuration cell.
progress_callback = print if REG_PROGRESS else None
results = register_planes_pass1(
    anatomy_stack=anatomy_stack,
    functional_planes=functional_planes,
    plane_indices=PLANE_INDICES,
    downscale_if_needed=False,
    scale_range=REG_SCALE_RANGE,
    n_scales=REG_N_SCALES,
    z_stride_coarse=REG_Z_STRIDE_COARSE,
    z_refine_radius=REG_Z_REFINE_RADIUS,
    gaussian_sigma=REG_GAUSSIAN_SIGMA,
    early_stop_score=REG_EARLY_STOP_SCORE,
    progress=progress_callback,
)

# One table row per result, carrying the attributes worth eyeballing.
summary_columns = ("plane_index", "best_z", "ncc", "success", "message")
summary_rows = [
    {column: getattr(result, column) for column in summary_columns}
    for result in results
]

summary_df = pd.DataFrame(summary_rows)
summary_df

In [None]:

import ipywidgets as widgets
import matplotlib.pyplot as plt
from IPython.display import display

# Controls: which plane to inspect and which z slice to show.
plane_dropdown = widgets.Dropdown(
    description="Plane",
    options=PLANE_INDICES,
)
z_slider = widgets.IntSlider(
    description="Z slice",
    min=0,
    max=anatomy_stack.shape[0] - 1,
    value=0,
)
# Output area that the render callback draws into.
output_view = widgets.Output()

# Lookup of successful results only, keyed by their plane index.
plane_data = dict(
    (result.plane_index, result) for result in results if result.success
)

# Re-entrancy flag for _render: adjusting the slider from inside the callback
# fires the slider's own `value` observer, which is _render again.
_render_guard = {"syncing": False}


@output_view.capture(clear_output=True)
def _render(_=None):
    """Draw the anatomy slice next to the warped plane for the current controls.

    Serves both as a direct call for the initial draw and as the ``value``
    observer of ``plane_dropdown`` and ``z_slider``. The single argument is the
    change payload supplied by the observer machinery and is ignored.

    The slider's upper bound is clamped to the last z index shared by the
    selected plane's ``warped_volume`` and ``anatomy_stack``. All slider
    writes happen before anything is drawn and under ``_render_guard``, so the
    nested observer call they trigger returns immediately instead of drawing a
    duplicate figure.
    """
    if _render_guard["syncing"]:
        # Triggered by our own slider adjustment below; the outer call draws.
        return

    plane_idx = plane_dropdown.value
    data = plane_data.get(plane_idx)
    if data is None:
        print("No successful registration for plane", plane_idx)
        return

    warped = data.warped_volume
    z_max = min(warped.shape[0]-1, anatomy_stack.shape[0]-1)
    if z_max < 0:
        # An empty volume would push the slider max below its min of 0.
        print("Empty volume for plane", plane_idx)
        return

    _render_guard["syncing"] = True
    try:
        if z_slider.max != z_max:
            z_slider.max = z_max
        z = min(z_slider.value, z_max)
        if z_slider.value != z:
            z_slider.value = z
    finally:
        _render_guard["syncing"] = False

    warped_slice = warped[z]
    fixed_slice = anatomy_stack[z]

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    axes[0].imshow(fixed_slice, cmap='gray')
    axes[0].set_title(f"Anatomy z={z}")
    axes[0].axis('off')

    axes[1].imshow(warped_slice, cmap='gray')
    axes[1].set_title(f"Warped plane {plane_idx} at z={z}")
    axes[1].axis('off')

    plt.tight_layout()
    plt.show()

# Redraw whenever either control's value changes.
for control_widget in (plane_dropdown, z_slider):
    control_widget.observe(_render, names='value')

# Controls stacked above the output area.
controls = widgets.VBox([plane_dropdown, z_slider])
viewer = widgets.VBox([controls, output_view])
display(viewer)

# Initial draw, so the view is populated before any interaction.
_render()


## Notes
Add refinement experiments here once pass 2 is implemented.