In [2]:
import json
import os
from pathlib import Path

import ants
from aind_smartspim_transform_utils.CoordinateTransform import CoordinateTransform

from deep_ccf_registration.datasets.slice_dataset import ExperimentMetadata, SliceDataset, \
    SliceOrientation

# Directory of the stitched SmartSPIM dataset. Set SMARTSPIM_DATASET_DIR to point at a local
# copy instead of editing this machine-specific default.
DATASET_DIR = Path(os.environ.get(
    'SMARTSPIM_DATASET_DIR',
    '/Users/adam.amster/Downloads/SmartSPIM_806624_2025-08-27_15-42-18_stitched_2025-08-28_13-34-06',
))
DATASET_META_PATH = DATASET_DIR / 'dataset_meta.json'

# Parse the raw JSON records into validated ExperimentMetadata models; the raw records keep
# their own name so `dataset_meta` always means the validated list.
with open(DATASET_META_PATH) as f:
    dataset_meta_raw = json.load(f)
dataset_meta = [ExperimentMetadata.model_validate(x) for x in dataset_meta_raw]

# Only the first experiment listed in the metadata is used in this notebook.
experiment_meta = dataset_meta[0]

# Coordinate transform configured with the experiment's ls->template affine matrix and inverse
# warp files, its acquisition axes, and its registered image shape.
coord_transform = CoordinateTransform(
    name='smartspim_lca',
    dataset_transforms={
        'points_to_ccf': [
            str(experiment_meta.ls_to_template_affine_matrix_path),
            str(experiment_meta.ls_to_template_inverse_warp_path),
        ]
    },
    acquisition_axes=experiment_meta.axes,
    image_metadata={'shape': experiment_meta.registered_shape}
)

In [14]:
# Display the template image's header (dimensions, spacing, origin, direction) so it can be
# compared with the inverse warp's header shown in the next cell.
coord_transform.ls_template

ANTsImage (RAS)
	 Pixel Type : float (float32)
	 Components : 1
	 Dimensions : (576, 648, 440)
	 Spacing    : (0.025, 0.025, 0.025)
	 Origin     : (-1.5114, -1.5, 1.5)
	 Direction  : [ 1.  0.  0.  0.  1.  0.  0.  0. -1.]

In [12]:
# Header of the inverse warp: a 3-component image whose displayed dimensions, spacing, origin
# and direction match the template's above.
# NOTE(review): this reads the whole field from disk just to show its header.
ants.image_read(str(experiment_meta.ls_to_template_inverse_warp_path))

ANTsImage (RAS)
	 Pixel Type : float (float32)
	 Components : 3
	 Dimensions : (576, 648, 440)
	 Spacing    : (0.025, 0.025, 0.025)
	 Origin     : (-1.5114, -1.5, 1.5)
	 Direction  : [ 1.  0.  0.  0.  1.  0.  0.  0. -1.]

In [2]:
from deep_ccf_registration.datasets import slice_dataset
from deep_ccf_registration.datasets.slice_dataset import _create_coordinate_dataframe

# Index along the slicing axis at which the 2D grid of points is sampled.
SLICE_INDEX = 200

# `slice_dataset` (the module) must be imported explicitly: the setup cell only imports names
# *from* it, so on a fresh kernel the bare name would be undefined.
# NOTE(review): assumes `_get_slice_axis` is a module-level function of slice_dataset — confirm.
slice_axis = slice_dataset._get_slice_axis(axes=experiment_meta.axes)

# The non-slice axes span the slice plane; collect their dimensions once, in the order they
# appear in `experiment_meta.axes`.
in_plane_dims = [axis.dimension for axis in experiment_meta.axes if axis.name != slice_axis.name]
assert len(in_plane_dims) == 2, f"expected 2 in-plane axes, got {len(in_plane_dims)}"
height = experiment_meta.registered_shape[in_plane_dims[0]]
width = experiment_meta.registered_shape[in_plane_dims[1]]

point_grid = _create_coordinate_dataframe(
    height=height,
    width=width,
    fixed_index_value=SLICE_INDEX,
)

In [19]:
# Reference result: the library's own forward_transform on the slice grid. The SimpleITK and
# SciPy reimplementations below are asserted against this output.
ls_template_points_ants = coord_transform.forward_transform(
    points=point_grid,
    points_resolution=list(experiment_meta.registered_resolution)
)

## Reimplementing the transform with the warp held in memory (SimpleITK, then SciPy)

In [3]:
import SimpleITK as sitk


def load_warp_sitk() -> sitk.DisplacementFieldTransform:
    """Build an in-memory SimpleITK transform from the experiment's inverse warp file.

    The field is cast to a float64 vector image, the pixel type that
    ``sitk.DisplacementFieldTransform`` accepts.
    """
    warp_path = str(experiment_meta.ls_to_template_inverse_warp_path)
    displacement_field = sitk.Cast(sitk.ReadImage(warp_path), sitk.sitkVectorFloat64)
    return sitk.DisplacementFieldTransform(displacement_field)


inverse_warp = load_warp_sitk()
inverse_warp

<SimpleITK.SimpleITK.DisplacementFieldTransform; proxy of <Swig Object of type 'itk::simple::DisplacementFieldTransform *' at 0x142cf2850> >

In [39]:
def load_warp_numpy():
    """Read the inverse warp with ANTs and return only its voxel data as a NumPy array.

    The image header (origin, spacing, direction) is not returned, so mapping physical points
    onto this array requires obtaining the geometry separately.
    """
    return ants.image_read(
        str(experiment_meta.ls_to_template_inverse_warp_path)
    ).numpy()


inverse_warp_numpy = load_warp_numpy()

In [27]:
from pathlib import Path
from aind_smartspim_transform_utils.utils import utils
import numpy as np
import pandas as pd

In [29]:
def apply_transform_sitk(warp: sitk.DisplacementFieldTransform, affine_path: Path, points=None):
    """Map points with the affine (via ANTs) followed by an in-memory SimpleITK displacement field.

    Steps: prepare the points with ``coord_transform.prepare_points_for_forward_transform``,
    apply the affine at ``affine_path`` with ``utils.apply_transforms_to_points``
    (``invert=(True,)``), displace each point with ``warp.TransformPoint``, then convert the
    result with ``utils.convert_from_ants_space``.

    Parameters
    ----------
    warp : sitk.DisplacementFieldTransform
        Displacement-field transform applied after the affine.
    affine_path : Path
        Path to the affine transform file.
    points : pandas.DataFrame, optional
        Points to transform, in the form ``prepare_points_for_forward_transform`` accepts.
        Defaults to the notebook-level ``point_grid``.

    Returns
    -------
    pandas.DataFrame
        Transformed points with columns ``ML``, ``AP``, ``DV``.
    """
    if points is None:
        points = point_grid

    prepared_points = coord_transform.prepare_points_for_forward_transform(
        points=points,
        points_resolution=list(experiment_meta.registered_resolution)
    )
    affine_transformed_points = utils.apply_transforms_to_points(
        ants_pts=prepared_points,
        transforms=[str(affine_path)],
        invert=(True,)
    )

    # Collect into an explicit float64 (N, 3) array rather than `np.zeros_like(input)`, which
    # would inherit (and silently downcast to) the input dtype and mis-shape an empty input.
    transformed_points = np.array(
        [warp.TransformPoint(point.tolist()) for point in affine_transformed_points],
        dtype=np.float64,
    ).reshape(-1, 3)

    transformed_points = utils.convert_from_ants_space(
        coord_transform.ls_template_info, transformed_points
    )

    return pd.DataFrame(transformed_points, columns=["ML", "AP", "DV"])


ls_template_points_sitk = apply_transform_sitk(warp=inverse_warp, affine_path=experiment_meta.ls_to_template_affine_matrix_path)

In [30]:
# The SimpleITK path should reproduce the library's ANTs-based result.
np.testing.assert_allclose(
    actual=ls_template_points_ants.to_numpy(),
    desired=ls_template_points_sitk.to_numpy(),
)

In [43]:
from scipy.ndimage import map_coordinates


def apply_transform_scipy(
    warp: np.ndarray,
    affine_path: Path,
    points=None,
    origin=None,
    spacing=None,
    direction=None,
):
    """Map points with the affine (via ANTs) followed by a warp sampled with SciPy.

    Reproduces ``apply_transform_sitk`` without SimpleITK. The warp is a raw voxel array, so the
    affine output (physical coordinates) is first converted to continuous voxel indices using
    the field's geometry. The displacement is then trilinearly interpolated and *added* to each
    point, as ITK's ``DisplacementFieldTransform.TransformPoint`` does.

    Parameters
    ----------
    warp : np.ndarray
        Displacement field with the vector components on the last axis and spatial axes in
        ITK index order (as returned by ``ants.image_read(...).numpy()``).
    affine_path : Path
        Path to the affine transform file.
    points : pandas.DataFrame, optional
        Points to transform; defaults to the notebook-level ``point_grid``.
    origin, spacing : sequence of float, optional
        Field origin and voxel spacing. Each defaults to ``coord_transform.ls_template``'s value.
    direction : array-like, optional
        Field direction cosines (3x3, or 9 values row-major). Defaults to
        ``coord_transform.ls_template.direction``.

    Returns
    -------
    pandas.DataFrame
        Transformed points with columns ``ML``, ``AP``, ``DV``.
    """
    if points is None:
        points = point_grid
    if origin is None or spacing is None or direction is None:
        # The raw array has no header. The displayed headers show the inverse warp and
        # `coord_transform.ls_template` share dimensions, spacing, origin and direction.
        template = coord_transform.ls_template
        origin = template.origin if origin is None else origin
        spacing = template.spacing if spacing is None else spacing
        direction = template.direction if direction is None else direction
    origin = np.asarray(origin, dtype=np.float64)
    spacing = np.asarray(spacing, dtype=np.float64)
    direction = np.asarray(direction, dtype=np.float64).reshape(3, 3)

    prepared_points = coord_transform.prepare_points_for_forward_transform(
        points=points,
        points_resolution=list(experiment_meta.registered_resolution)
    )
    affine_transformed_points = np.asarray(
        utils.apply_transforms_to_points(
            ants_pts=prepared_points,
            transforms=[str(affine_path)],
            invert=(True,)
        ),
        dtype=np.float64,
    )

    # map_coordinates samples at voxel indices, but the affine output is in physical space.
    # Invert ITK's index -> physical mapping  p = origin + D @ (spacing * i).
    voxel_indices = ((affine_transformed_points - origin) @ np.linalg.inv(direction).T) / spacing

    displacements = np.zeros((len(affine_transformed_points), 3), dtype=np.float64)
    for component in range(3):
        displacements[:, component] = map_coordinates(
            warp[..., component],
            voxel_indices.T,
            order=1,
            mode='nearest',
            prefilter=False,
            # Interpolate in double precision (the field itself may be float32).
            output=np.float64,
        )

    # Mirror ITK: within half a voxel of the border the edge value is used (mode='nearest'
    # above), while points outside the field keep zero displacement.
    field_shape = np.asarray(warp.shape[:3])
    inside = np.all((voxel_indices >= -0.5) & (voxel_indices < field_shape - 0.5), axis=1)
    displacements[~inside] = 0.0

    # Displacements are added (p + d(p)), matching SimpleITK's TransformPoint.
    transformed_points = affine_transformed_points + displacements

    transformed_points = utils.convert_from_ants_space(
        coord_transform.ls_template_info, transformed_points
    )

    return pd.DataFrame(transformed_points, columns=["ML", "AP", "DV"])

In [45]:
# Same pipeline as the SimpleITK version, with the warp sampled by scipy.ndimage.map_coordinates.
ls_template_points_numpy = apply_transform_scipy(
    warp=inverse_warp_numpy,
    affine_path=experiment_meta.ls_to_template_affine_matrix_path,
)

In [46]:
# The SciPy path should reproduce the library's ANTs-based result.
np.testing.assert_allclose(
    actual=ls_template_points_ants.to_numpy(),
    desired=ls_template_points_numpy.to_numpy(),
)