# **Step 0:** Definitions, Imports

In [2]:
from operator import sub
import numpy as np
from nd2reader import ND2Reader
import itk
from scipy.spatial.transform import Rotation as R

def elastix_similarity_to_matrix(elastix_parameter_object, ndim=3):
    """Convert an elastix similarity result into an augmented numpy matrix.

    Selects the 2D or 3D converter according to the image dimensionality.

    Parameters
    ----------
    elastix_parameter_object : itk.ParameterObject
        Transform parameters returned by the elastix registration.
    ndim : int, optional
        Dimensionality of the registered images, 2 or 3 (default 3).

    Returns
    -------
    np.ndarray
        Augmented (ndim + 1) x (ndim + 1) transform matrix.

    Raises
    ------
    ValueError
        If ``ndim`` is neither 2 nor 3.
    """
    if ndim == 2:
        converter = elastix_similarity_to_matrix_2d
    elif ndim == 3:
        converter = elastix_similarity_to_matrix_3d
    else:
        raise ValueError('only 2D/3D similarity transform supported at the moment')
    return converter(elastix_parameter_object)

def elastix_similarity_to_matrix_2d(elastix_parameter_object):
    """Convert a 2D elastix SimilarityTransform into a 3x3 augmented matrix.

    Elastix stores points and parameters in (x, y) order, whereas the
    returned matrix acts on homogeneous coordinates in numpy (y, x, 1) order.

    Parameters
    ----------
    elastix_parameter_object : itk.ParameterObject
        Result parameters of an elastix registration; only the first
        parameter map (index 0) is read.

    Returns
    -------
    np.ndarray
        3x3 matrix ``c @ t @ s @ r @ inv(c)`` in world units, mapping
        fixed-image points to moving-image points.

    Raises
    ------
    ValueError
        If the first map is not a 4-parameter SimilarityTransform.
    """
    # check if we actually have the right type of transform
    transform_type, = elastix_parameter_object.GetParameter(0, 'Transform')
    n_parameters, = elastix_parameter_object.GetParameter(0, 'NumberOfParameters')

    if transform_type != 'SimilarityTransform' or int(n_parameters) != 4:
        raise ValueError('only 2D similarity transform supported at the moment')

    # get center of rotation, NOTE: GetParameter returns tuple of str, map to float
    cx, cy = map(float, elastix_parameter_object.GetParameter(0, 'CenterOfRotationPoint'))
    # 2D similarity parameters: scale, rotation angle (radians), translation x, y
    s, rot, tx, ty = map(float, elastix_parameter_object.GetParameter(0, 'TransformParameters'))

    # build augmented matrices for individual steps, all in (y, x) order
    # move to center
    c_mat = np.eye(3)
    c_mat[:-1, -1] = [cy, cx]

    # move translation
    t_mat = np.eye(3)
    t_mat[:-1, -1] = [ty, tx]

    # scale
    s_mat = np.diag([s, s, 1])

    # rotation: elastix rotates (x, y) by [[cos, -sin], [sin, cos]].
    # Swapping the axes to (y, x) gives P @ R @ P = [[cos, sin], [-sin, cos]]
    # (the same rotation seen in mirrored axes turns the other way).
    # Building it explicitly keeps the bottom row exactly [0, 0, 1].
    cos_r, sin_r = np.cos(rot), np.sin(rot)
    r_mat = np.array([
        [cos_r, sin_r, 0.0],
        [-sin_r, cos_r, 0.0],
        [0.0, 0.0, 1.0],
    ])

    # final (similarity) transform matrix constructed as in elastix documentation, right-to-left!
    # add c @ add t @ scale @ r @ sub c
    mat = c_mat @ t_mat @ s_mat @ r_mat @ np.linalg.inv(c_mat)

    return mat

def elastix_similarity_to_matrix_3d(elastix_parameter_object):
    """Convert a 3D elastix SimilarityTransform into a 4x4 augmented matrix.

    Elastix stores points and parameters in (x, y, z) order, whereas the
    returned matrix acts on homogeneous coordinates in numpy (z, y, x, 1) order.

    Parameters
    ----------
    elastix_parameter_object : itk.ParameterObject
        Result parameters of an elastix registration; only the first
        parameter map (index 0) is read.

    Returns
    -------
    np.ndarray
        4x4 matrix ``c @ t @ s @ r @ inv(c)`` in world units, mapping
        fixed-image points to moving-image points.

    Raises
    ------
    ValueError
        If the first map is not a 7-parameter SimilarityTransform.
    """
    # check if we actually have the right type of transform
    transform_type, = elastix_parameter_object.GetParameter(0, 'Transform')
    n_parameters, = elastix_parameter_object.GetParameter(0, 'NumberOfParameters')

    if transform_type != 'SimilarityTransform' or int(n_parameters) != 7:
        raise ValueError('only 3D similarity transform supported at the moment')

    # get center of rotation, NOTE: GetParameter returns tuple of str, map to float
    cx, cy, cz = map(float, elastix_parameter_object.GetParameter(0, 'CenterOfRotationPoint'))
    # 3D similarity parameters: versor (vector part of unit quaternion) x, y, z,
    # translation x, y, z, isotropic scale
    qx, qy, qz, tx, ty, tz, s = map(float, elastix_parameter_object.GetParameter(0, 'TransformParameters'))

    # the versor only stores the vector part; the scalar part follows from unit norm
    qw = np.sqrt(max(0.0, 1.0 - (qx * qx + qy * qy + qz * qz)))

    # build augmented matrices for individual steps, all in (z, y, x) order
    # move to center
    c_mat = np.eye(4)
    c_mat[:-1, -1] = [cz, cy, cx]

    # move translation
    t_mat = np.eye(4)
    t_mat[:-1, -1] = [tz, ty, tx]

    # scale
    s_mat = np.diag([s, s, s, 1])

    # rotation: reversing the axes (x, y, z) -> (z, y, x) is an improper
    # permutation P, and P @ R @ P rotates about the permuted axis by the
    # opposite angle. The quaternion vector part is therefore reversed AND
    # negated. scipy expects scalar-last (x, y, z, w) order.
    r_mat = np.eye(4)
    r_mat[:3, :3] = R.from_quat([-qz, -qy, -qx, qw]).as_matrix()

    # final (similarity) transform matrix constructed as in elastix documentation, right-to-left!
    # add c @ add t @ scale @ r @ sub c
    mat = c_mat @ t_mat @ s_mat @ r_mat @ np.linalg.inv(c_mat)

    return mat

def world_coordinate_transform_to_pixel(transform_matrix, pixel_size):
    """Re-express an augmented transform given in world units in pixel units.

    Pixel coordinates are scaled to world coordinates, the world transform is
    applied, and the result is scaled back to pixels:
    ``inv(S) @ transform_matrix @ S`` with ``S = diag(*pixel_size, 1)``.

    Parameters
    ----------
    transform_matrix : np.ndarray
        Augmented (n + 1) x (n + 1) matrix operating on world coordinates.
    pixel_size : sequence of float
        Pixel size per axis (n entries), in the same axis order as the matrix.

    Returns
    -------
    np.ndarray
        The equivalent augmented matrix operating on pixel coordinates.
    """
    # pixel -> world scaling, with the homogeneous coordinate left untouched
    world_from_pixel = np.diag([*pixel_size, 1])
    pixel_from_world = np.linalg.inv(world_from_pixel)
    return pixel_from_world @ transform_matrix @ world_from_pixel

def elastix_registration(img1, img2, pixel_size, number_of_spatial_samples=8192):
    """Estimate a similarity transform between two images with elastix.

    Parameters
    ----------
    img1 : np.ndarray
        Fixed (target) image, 2D or 3D, in numpy axis order.
    img2 : np.ndarray
        Moving image with the same dimensionality as ``img1``.
    pixel_size : sequence of float
        Pixel size per axis in numpy axis order; the returned transform is
        expressed in these (world) units.
    number_of_spatial_samples : int, optional
        Value for elastix's ``NumberOfSpatialSamples`` (default 8192).

    Returns
    -------
    np.ndarray
        Augmented (ndim + 1) x (ndim + 1) similarity matrix as produced by
        ``elastix_similarity_to_matrix`` for the estimated parameters.
    """
    # numpy -> ITK (as float32)
    img_target = itk.image_from_array(img1.astype(np.float32))
    img_moving = itk.image_from_array(img2.astype(np.float32))
    # ITK spacing is given in reversed (x-first) order relative to numpy axes;
    # setting it makes elastix report the transform in world coordinate units
    img_target.SetSpacing(pixel_size[::-1])
    img_moving.SetSpacing(pixel_size[::-1])

    # construct ITK parameter object
    elastix_parameters = itk.ParameterObject.New()
    # start from the default affine map and switch the transform to similarity
    similarity_parameter_map = elastix_parameters.GetDefaultParameterMap('affine')
    similarity_parameter_map['Transform'] = ['SimilarityTransform']
    similarity_parameter_map['NumberOfSpatialSamples'] = [str(number_of_spatial_samples)]
    elastix_parameters.AddParameterMap(similarity_parameter_map)

    # run registration; only the estimated transform parameters are used
    _, estimated_transform_parameters = itk.elastix_registration_method(
        img_target, img_moving, parameter_object=elastix_parameters)

    return elastix_similarity_to_matrix(estimated_transform_parameters, img1.ndim)

# **Step 1:** Read images of different channels

We want a dictionary of images of the different channels, like so:

```
images = {
    'channel_1_name': image_channel_1,
    'channel_2_name': image_channel_2,
    ...
}
```

Additionally, we want an array of pixel sizes. The unit can be arbitrary, but should be specified:

```
pixel_size = [pixel_size_z, pixel_size_y, pixel_size_x]
pixel_unit = 'micron'
```

Also, we want to know whether increasing z planes move towards the sample or away from it (towards the coverslip), so that the saved transformations can also be applied to images acquired in the opposite direction. Note that ```z_direction``` can be left blank, i.e., set to ```None```, if the direction is unknown — in that case, however, it cannot be taken into account when applying the transformation. (The nd2 reader in Option 1 below derives the direction from the z coordinates and stores it as ```'top_to_bottom'``` or ```'bottom_to_top'```.)

```
z_direction = 'to_sample' | 'from_sample' | None
```

**We support several input options; also check ```chromatic_aberration_estimation.ipynb``` for other file formats.**

## **Option 1:** Nikon nd2 files

In [3]:
image_path = "/home/stumberger/ep2024/example/chrom_shift_reference/23AM09-03_003.nd2"

# spinning disk data may have additional magnification of 1.5x
# leave at 1.0 unless you are sure you used the extra zoom
magnification = 1.0

# set to True to estimate 2d transformation in max. projection of data
do_maxprojection_2d = False

# read all channels into dict of channel_name -> img
images = {}
with ND2Reader(image_path) as reader:

    reader.bundle_axes = ['z', 'y', 'x']
    reader.iter_axes = ['c']

    for i, channel_name in enumerate(reader.metadata['channels']):
        img = np.array(reader[i])
        if do_maxprojection_2d:
            img = img.max(axis=0)
        images[channel_name.strip().replace(' ', '-')] = img

    psz_xy = reader.metadata['pixel_microns'] / magnification
    # difference of z position of first two planes -> z-spacing
    psz_z = sub(*reader.metadata['z_coordinates'][:2])
    z_direction = 'top_to_bottom' if psz_z > 0 else 'bottom_to_top'
    # for pixel size, use absolute spacing
    psz_z = abs(psz_z)

# pixel size to array
pixel_size = np.array([psz_xy, psz_xy]) if do_maxprojection_2d else np.array([psz_z, psz_xy, psz_xy])
pixel_unit = 'micron'

print(f'read nd2 file with {len(images)} channels: {list(images.keys())}')
print('image shapes:')
for channel_name, v in images.items():
    print(f'{channel_name}: {v.shape}')
print(f'pixel size: {pixel_size} {pixel_unit}')

read nd2 file with 4 channels: ['405-CSU-W1', '488-CSU-W1', '561-CSU-W1', '640-CSU-W1']
image shapes:
405-CSU-W1: (71, 1024, 1024)
488-CSU-W1: (71, 1024, 1024)
561-CSU-W1: (71, 1024, 1024)
640-CSU-W1: (71, 1024, 1024)
pixel size: [0.225 0.13  0.13 ] micron


## Optional: View images in napari

Confirm that you loaded the correct data.

In [5]:
import napari

colormaps_default = ['blue', 'green', 'yellow', 'red', 'cyan', 'magenta']

# close a viewer left open by a previous run of this cell
if napari.current_viewer() is not None:
    napari.current_viewer().close()

viewer = napari.Viewer()
for i, (channel_name, image) in enumerate(images.items()):
    # cycle through the colormaps so files with more channels than colormaps
    # do not raise an IndexError
    colormap = colormaps_default[i % len(colormaps_default)]
    # scale=pixel_size displays the data in physical units
    viewer.add_image(image, colormap=colormap, name=channel_name, blending='additive', scale=pixel_size)


ModuleNotFoundError: No module named 'napari'

# **Step 2:** Perform Alignment

In [4]:
from itertools import combinations

# similarity matrices (world units) keyed by (ch1, ch2) channel-name tuples
transforms = {}
# register every unordered channel pair once; the reverse direction is
# stored as the matrix inverse instead of running elastix again
for (ch1, ch2) in combinations(images.keys(), 2):

    img1 = images[ch1]
    img2 = images[ch2]

    # img1 is passed as the fixed (target) image, img2 as the moving image
    transform = elastix_registration(img1, img2, pixel_size)

    # NOTE: elastix seems to return img1 -> img2 transform
    transforms[(ch1, ch2)] = transform
    transforms[(ch2, ch1)] = np.linalg.inv(transform)

    print(f'estimated similarity transform between {ch1} and {ch2} with elastix')

estimated similarity transform between 405-CSU-W1 and 488-CSU-W1 with elastix
estimated similarity transform between 405-CSU-W1 and 561-CSU-W1 with elastix
estimated similarity transform between 405-CSU-W1 and 640-CSU-W1 with elastix
estimated similarity transform between 488-CSU-W1 and 561-CSU-W1 with elastix
estimated similarity transform between 488-CSU-W1 and 640-CSU-W1 with elastix
estimated similarity transform between 561-CSU-W1 and 640-CSU-W1 with elastix


# (Optional) **Step 3:** View Aligned Images to verify

In [24]:
try:
    from dask_image.ndinterp import affine_transform
    print('will use dask-image for image transformation')
except ImportError:
    from scipy.ndimage import affine_transform
    print('will use scipy for image transformation, consider dask-image for higher speed')

# all other channels are resampled onto the grid of this channel
reference_channel = '561-CSU-W1'

# fail early with a readable message instead of a KeyError inside the loop
if reference_channel not in images:
    raise ValueError(f'reference channel {reference_channel!r} not found, available channels: {list(images.keys())}')

images_aligned = {}
for ch, image in images.items():

    if ch == reference_channel:
        images_aligned[ch] = image
        print(f'keep image of channel {ch} as-is (reference)')
        continue

    # NOTE: we want the inverse transform from ch to reference, i.e. the transform reference -> ch
    # (affine_transform samples the input at mat @ output_coordinate, so it needs
    # the mapping from the output (reference) grid into the input (ch) image)
    mat = transforms[(reference_channel, ch)]
    # transforms are stored in world units, affine_transform works on pixel indices
    mat = world_coordinate_transform_to_pixel(mat, pixel_size)

    # order=2: quadratic spline interpolation
    image_transformed = affine_transform(image, mat, order=2)
    # np.array() also materializes the result when dask-image returned a dask array
    images_aligned[ch] = np.array(image_transformed)

    print(f'aligned image of channel {ch}')

will use scipy for image transformation, consider dask-image for higher speed
keep image of channel 561-CSU-W1 as-is (reference)
aligned image of channel 640-CSU-W1


In [None]:
import napari

colormaps_default = ['blue', 'green', 'yellow', 'red', 'cyan', 'magenta']

# close a viewer left open by a previous run of this cell
if napari.current_viewer() is not None:
    napari.current_viewer().close()

viewer = napari.Viewer()
for i, (channel_name, image) in enumerate(images_aligned.items()):
    # cycle through the colormaps so files with more channels than colormaps
    # do not raise an IndexError
    colormap = colormaps_default[i % len(colormaps_default)]
    # scale=pixel_size displays the data in physical units
    viewer.add_image(image, colormap=colormap, name=channel_name, scale=pixel_size, blending='additive')

# **Step 4:** Save Transformation Parameters

In [5]:
import json
from pathlib import Path

# write next to the input file: <stem>_channel_registration[_maxproj].json
if do_maxprojection_2d:
    out_file = Path(image_path).parent / (Path(image_path).stem + '_channel_registration_maxproj.json')
else:
    out_file = Path(image_path).parent / (Path(image_path).stem + '_channel_registration.json')
# out_file = '/data/agl_data/NanoFISH/Gabi/GS534_beads_coloc/sparse_channel_registration_560_640.json'

# .tolist() turns numpy scalars into plain Python numbers: json only accepts
# np.float64 because it subclasses float, and fails on e.g. np.float32
output = {
    'channels' : list(images.keys()),
    'pixel_size' : np.asarray(pixel_size).tolist(),
    'size_unit' : pixel_unit,
    'z_direction' : z_direction,
    'field_of_view' : (np.array(next(iter(images.values())).shape) * pixel_size).tolist(),
    'source_file': str(image_path),
    # each augmented matrix is flattened row-major (numpy .ravel() default)
    'transforms' : [ {'channels' : list(k), 'parameters': np.asarray(v).ravel().tolist()} for k,v in transforms.items()]
}

with open(out_file, 'w', encoding='utf-8') as fd:
    json.dump(output, fd, indent=1)