In [110]:
import ants

from deep_ccf_registration.datasets.slice_dataset import ExperimentMetadata, SliceDataset, \
    SliceOrientation
import json
from aind_smartspim_transform_utils.CoordinateTransform import CoordinateTransform

with open('/Users/adam.amster/Downloads/SmartSPIM_806624_2025-08-27_15-42-18_stitched_2025-08-28_13-34-06/dataset_meta.json') as f:
    dataset_meta = json.load(f)
dataset_meta = [ExperimentMetadata.model_validate(x) for x in dataset_meta]

experiment_meta = dataset_meta[0]

In [70]:


coord_transform = CoordinateTransform(
    name='smartspim_lca',
    dataset_transforms={
        'points_to_ccf': [
            str(experiment_meta.ls_to_template_affine_matrix_path),
            str(experiment_meta.ls_to_template_inverse_warp_path),
        ]
    },
    acquisition_axes=experiment_meta.axes,
    image_metadata={'shape': experiment_meta.registered_shape},
    ls_template=ants.image_read('/Users/adam.amster/.transform_utils/transform_utils/smartspim_lca/template/smartspim_lca_template_25.nii.gz')
)

In [14]:
coord_transform.ls_template

ANTsImage (RAS)
	 Pixel Type : float (float32)
	 Components : 1
	 Dimensions : (576, 648, 440)
	 Spacing    : (0.025, 0.025, 0.025)
	 Origin     : (-1.5114, -1.5, 1.5)
	 Direction  : [ 1.  0.  0.  0.  1.  0.  0.  0. -1.]

In [12]:
ants.image_read(str(experiment_meta.ls_to_template_inverse_warp_path))

ANTsImage (RAS)
	 Pixel Type : float (float32)
	 Components : 3
	 Dimensions : (576, 648, 440)
	 Spacing    : (0.025, 0.025, 0.025)
	 Origin     : (-1.5114, -1.5, 1.5)
	 Direction  : [ 1.  0.  0.  0.  1.  0.  0.  0. -1.]

In [71]:
from deep_ccf_registration.datasets.slice_dataset import _create_coordinate_dataframe

height = experiment_meta.registered_shape[[x.dimension for x in experiment_meta.axes if x.name.value != 'Z'][0]]
width = experiment_meta.registered_shape[[x.dimension for x in experiment_meta.axes if x.name.value != 'Z'][1]]

point_grid = _create_coordinate_dataframe(
    height=height,
    width=width,
    fixed_index_value=200
)

In [111]:
ls_template_points_ants = coord_transform.forward_transform(
    points=point_grid,
    points_resolution=list(experiment_meta.registered_resolution)
)

In [57]:
import pandas as pd


def sample_ls_template_at_points(ls_template_points: pd.DataFrame, ls_template: ants.ANTsImage, interpolation='nearest'):
    ls_template_points = ls_template_points.values

    # Check bounds
    valid_mask = (
        (ls_template_points[:, 0] >= 0) & (ls_template_points[:, 0] < ls_template.shape[0]) &
        (ls_template_points[:, 1] >= 0) & (ls_template_points[:, 1] < ls_template.shape[1]) &
        (ls_template_points[:, 2] >= 0) & (ls_template_points[:, 2] < ls_template.shape[2])
    )

    values = np.zeros(len(ls_template_points))

    if interpolation == 'nearest':
        # Round to nearest integer
        ls_template_points = np.floor(ls_template_points[valid_mask]).astype(int)
        values[valid_mask] = ls_template[
            ls_template_points[:, 0],
            ls_template_points[:, 1],
            ls_template_points[:, 2]
        ]
    else:
        raise NotImplementedError

    return values, valid_mask

In [61]:
from deep_ccf_registration.utils.utils import visualize_alignment, sample_template_at_points
import tensorstore
import torch
import matplotlib.pyplot as plt
import numpy as np

template_points = torch.tensor(ls_template_points_ants.values)

raw = torch.tensor(tensorstore.open('file:///Users/adam.amster/Downloads/SmartSPIM_806624_2025-08-27_15-42-18_stitched_2025-08-28_13-34-06/Ex_639_Em_680.zarr/3', read=True).result()[:].read().result())

ls_template_on_input = sample_ls_template_at_points(ls_template_points=ls_template_points_ants, ls_template=coord_transform.ls_template, interpolation='nearest')[0].reshape((raw.shape[-2], raw.shape[-1]))

fig, ax = plt.subplots(figsize=(10, 5))
ax.imshow(ls_template_on_input, cmap='gray')
fig.savefig('/tmp/foo.png')

# visualize_alignment(
#     input_slice=raw[0, 0, 200],
#     template_points=ls_template_on_input,
#     template=coord_transform.ls_template
# )

ValueError: setting an array element with a sequence.

## SimpleITK using in-memory warp

In [3]:
import SimpleITK as sitk

def load_warp_sitk():
    warp_img = sitk.ReadImage(str(experiment_meta.ls_to_template_inverse_warp_path))
    warp_img = sitk.Cast(warp_img, sitk.sitkVectorFloat64)
    inverse_warp = sitk.DisplacementFieldTransform(warp_img)
    return inverse_warp
inverse_warp = load_warp_sitk()
inverse_warp

<SimpleITK.SimpleITK.DisplacementFieldTransform; proxy of <Swig Object of type 'itk::simple::DisplacementFieldTransform *' at 0x142cf2850> >

In [138]:
def load_warp_numpy():
    warp_img = ants.image_read(str(experiment_meta.ls_to_template_inverse_warp_path))
    return warp_img.numpy()
inverse_warp_numpy = load_warp_numpy()

In [27]:
from pathlib import Path
from aind_smartspim_transform_utils.utils import utils
import numpy as np
import pandas as pd

In [23]:
def apply_transform_sitk(warp: sitk.DisplacementFieldTransform, affine_path: Path):
    points = coord_transform.prepare_points_for_forward_transform(
        points=point_grid,
        points_resolution=list(experiment_meta.registered_resolution)
    )
    affine_transformed_points = utils.apply_transforms_to_points(
        ants_pts=points,
        transforms=[str(affine_path)],
        invert=(True,)
    )
    transformed_points = np.zeros_like(affine_transformed_points)

    for i, point in enumerate(affine_transformed_points):
        transformed_points[i] = warp.TransformPoint(point.tolist())

    transformed_points = utils.convert_from_ants_space(
        coord_transform.ls_template_info, transformed_points
    )

    transformed_df = pd.DataFrame(
        transformed_points, columns=["ML", "AP", "DV"]
    )
    return transformed_df
ls_template_points_sitk = apply_transform_sitk(warp=inverse_warp, affine_path=experiment_meta.ls_to_template_affine_matrix_path)

NameError: name 'sitk' is not defined

In [30]:
np.testing.assert_allclose(ls_template_points_ants.values, ls_template_points_sitk.values)

In [133]:
from scipy.ndimage import map_coordinates
import numpy as np
from pathlib import Path
from aind_smartspim_transform_utils.utils import utils
import pandas as pd

def apply_transform_scipy(warp: np.ndarray, affine_path: Path):
    points = coord_transform.prepare_points_for_forward_transform(
        points=point_grid,
        points_resolution=list(experiment_meta.registered_resolution)
    )
    affine_transformed_points = utils.apply_transforms_to_points(
        ants_pts=points,
        transforms=[str(affine_path)],
        invert=(True,)
    )

    # Convert to voxel indices for sampling
    voxel_indices = utils.convert_from_ants_space(
        coord_transform.ls_template_info,
        affine_transformed_points
    )

    displacements = np.zeros((len(affine_transformed_points), 3))

    for component in range(3):
        displacements[:, component] = map_coordinates(
            warp[:, :, :, component],
            voxel_indices.T,
            order=1,
            mode='nearest',
            prefilter=False
        )

    # ADD the displacements, don't subtract!
    transformed_points = affine_transformed_points + displacements  # Changed from -

    transformed_points = utils.convert_from_ants_space(
        coord_transform.ls_template_info, transformed_points
    )

    transformed_df = pd.DataFrame(
        transformed_points, columns=["ML", "AP", "DV"]
    )
    return transformed_df

In [134]:
ls_template_points_numpy = apply_transform_scipy(warp=inverse_warp_numpy, affine_path=experiment_meta.ls_to_template_affine_matrix_path)

In [135]:
np.testing.assert_allclose(ls_template_points_ants.values, ls_template_points_numpy.values)

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 1398615 / 1589760 (88%)
Max absolute difference among violations: 0.1731825
Max relative difference among violations: 0.22382989
 ACTUAL: array([[178.01631 , 619.24646 , 367.4785  ],
       [178.0243  , 619.2837  , 366.85626 ],
       [178.03226 , 619.3209  , 366.234   ],...
 DESIRED: array([[178.016512, 619.246393, 367.478275],
       [178.0245  , 619.283611, 366.856041],
       [178.032471, 619.320853, 366.233784],...

In [132]:
# Compare the full pipeline step by step
test_res = [16, 14.4, 14.4]
points = coord_transform.prepare_points_for_forward_transform(
    points=point_grid[:10],  # Just a few points
    points_resolution=test_res
)

# Step 1: Affine
affine_only = utils.apply_transforms_to_points(
    points,
    [str(experiment_meta.ls_to_template_affine_matrix_path)],
    invert=(True,)
)

# Step 2: Get voxel indices
voxel_indices = utils.convert_from_ants_space(
    coord_transform.ls_template_info,
    affine_only
)

# Step 3: Sample displacements
scipy_displacements = np.zeros((len(affine_only), 3))
for component in range(3):
    scipy_displacements[:, component] = map_coordinates(
        inverse_warp_numpy[:, :, :, component],
        voxel_indices.T,
        order=1,
        mode='nearest',
        prefilter=False
    )

# Step 4: Apply displacements
scipy_after_warp = affine_only - scipy_displacements

# Step 5: Convert from ANTs space
scipy_final = utils.convert_from_ants_space(
    coord_transform.ls_template_info,
    scipy_after_warp
)

# Compare with ANTs
ants_result = utils.apply_transforms_to_points(
    points,
    coord_transform.dataset_transforms["points_to_ccf"],
    invert=(True, False)
)

# ALSO convert ANTs result from ANTs space for fair comparison
ants_final = utils.convert_from_ants_space(
    coord_transform.ls_template_info,
    ants_result
)

print("After affine (physical):", affine_only[0])
print("Voxel indices:", voxel_indices[0])
print("Displacement:", scipy_displacements[0])
print("After applying displacement (physical):", scipy_after_warp[0])
print("Scipy final (voxel):", scipy_final[0])
print("ANTs result (physical):", ants_result[0])
print("ANTs final (voxel):", ants_final[0])
print(f"\nDo they match? {np.allclose(scipy_final, ants_final, atol=1e-5)}")

# Check if the issue is in the final conversion
print(f"\nBefore final conversion, match? {np.allclose(scipy_after_warp, ants_result, atol=1e-5)}")

After affine (physical): [ 2.9388764 13.98133   -7.687716 ]
Voxel indices: [178.01065 619.2532  367.5086 ]
Displacement: [ 0.00014184 -0.00016771  0.00075321]
After applying displacement (physical): [ 2.93873455 13.98149762 -7.68846922]
Scipy final (voxel): [178.00497805 619.25989571 367.53876341]
ANTs result (physical): [ 2.9390182 13.981162  -7.6869626]
ANTs final (voxel): [178.01631 619.24646 367.4785 ]

Do they match? False

Before final conversion, match? False


## compare ants.image_read runtime

In [3]:
ants.image_read(str(experiment_meta.ls_to_template_inverse_warp_path))

ANTsImage (RAS)
	 Pixel Type : float (float32)
	 Components : 3
	 Dimensions : (576, 648, 440)
	 Spacing    : (0.025, 0.025, 0.025)
	 Origin     : (-1.5114, -1.5, 1.5)
	 Direction  : [ 1.  0.  0.  0.  1.  0.  0.  0. -1.]

In [4]:
ants.image_read(str(experiment_meta.ls_to_template_inverse_warp_path), pixeltype=None)

ANTsImage (RAS)
	 Pixel Type : float (float32)
	 Components : 3
	 Dimensions : (576, 648, 440)
	 Spacing    : (0.025, 0.025, 0.025)
	 Origin     : (-1.5114, -1.5, 1.5)
	 Direction  : [ 1.  0.  0.  0.  1.  0.  0.  0. -1.]

In [5]:
import nibabel

In [6]:
nibabel.load(str(experiment_meta.ls_to_template_inverse_warp_path)).get_fdata()

array([[[[[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         ...,

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]]],


        [[[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         ...,

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]]],


        [[[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         ...,

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]]],


        ...,


        [[[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         ...,

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]]],


        [[[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         ...,

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]]],


        [[[0., 0., 0.]],

         [[0., 0., 0.]],

         [[0., 0., 0.]],

         ...,

         [[0., 0., 0.]],

 

In [137]:
import tensorstore

In [141]:
compression_codec = {
    "name": "blosc",
    "configuration": {
        "cname": "zstd",
        "clevel": 5,
        "shuffle": "shuffle",
    },
}

ts = tensorstore.open({'driver': 'zarr3', 'kvstore': {'driver': 'file', 'path': '/Users/adam.amster/Downloads/SmartSPIM_806624_2025-08-27_15-42-18_stitched_2025-08-28_13-34-06/ls_to_template_SyN_1InverseWarp.zarr'}},             chunk_layout=tensorstore.ChunkLayout(
                chunk=tensorstore.ChunkLayout.Grid(shape=(576, 648, 440, 1))
            ), shape=(576, 648, 440, 3), create=True, dtype=inverse_warp_numpy.dtype, delete_existing=True,
                codec=tensorstore.CodecSpec({"driver": "zarr3", "codecs": [compression_codec]}),
                      ).result()
ts[:] = inverse_warp_numpy

In [14]:
ts[:, :, :, 0].read().result()

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [140]:
compression_codec = {
    "name": "blosc",
    "configuration": {
        "cname": "zstd",
        "clevel": 5,
        "shuffle": "shuffle",
    },
}
ts = tensorstore.open({'driver': 'zarr3', 'kvstore': {'driver': 'file', 'path': '/Users/adam.amster/Downloads/SmartSPIM_806624_2025-08-27_15-42-18_stitched_2025-08-28_13-34-06/ls_to_template_SyN_1InverseWarp_fp16.zarr'}},             chunk_layout=tensorstore.ChunkLayout(
                # chunk=tensorstore.ChunkLayout.Grid(shape=(576, 648, 440, 3))
            ), shape=(576, 648, 440, 3), create=True, dtype=tensorstore.float16, delete_existing=True,
                              codec=tensorstore.CodecSpec({"driver": "zarr3", "codecs": [compression_codec]}),
                      ).result()
ts[:] = inverse_warp_numpy.astype('float16')

In [16]:
ts[:, :, :, 0].read().result()

array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0.

In [19]:
from deep_ccf_registration.datasets.slice_dataset import SliceDataset
from pathlib import Path

slice_dataset = SliceDataset(
    dataset_meta=dataset_meta,
    orientation=SliceOrientation.SAGITTAL,
    ls_template_path=Path('/Users/adam.amster/.transform_utils/transform_utils/smartspim_lca/template/smartspim_lca_template_25.nii.gz')
)

TypeError: SliceDataset.__init__() got an unexpected keyword argument 'ls_template_path'