In [None]:
from pathlib import Path
import json
from collections import namedtuple
import re

from h5py import File
from glob import glob
from utils.transform_helpers import get_scan_field_metadata_h5

from tqdm.notebook import tqdm
import numpy as np
from scipy.spatial import KDTree
from skimage.io import imsave

from msr_reader import OBFFile
from utils.transform_helpers import get_scan_field_metadata, world_coords_for_pixel_spots
from calmutils.stitching import get_axes_aligned_overlap
from calmutils.stitching.fusion import fuse_image
from calmutils.imageio import save_tiff_imagej
from calmutils.color import gray_images_to_rgb_composite
from utils.transform_helpers import world_transform_to_pixel_transform
from calmutils.misc.visualization import get_orthogonal_projections_8bit

# Record type for one pre-loaded moving image: the per-channel pixel arrays,
# its transform, the coordinates of its origin and center, and its pixel size.
MovingImageData = namedtuple(
    'MovingImageData',
    ['imgs', 'transform', 'coord_origin', 'coord_center', 'pixel_size'],
)


In [None]:
# --- input locations ---------------------------------------------------------
# root folders of the target (reference) dataset and of the moving dataset
base_path_target = "/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250823_DNAFISH/"
base_path_moving = "/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250821_RNAFISH/"

# subfolder of the respective root folder that holds the raw .h5 files
raw_subdir_target = "raw"
raw_subdir_moving = "raw"

# JSON file containing the transformation lists per image id
alignment_params_file = "/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250821_RNAFISH/transformations/transformations_local_1round.json"

# --- acquisition filters -----------------------------------------------------
# Patterns are matched against the acquisition id (dataset id in the H5 file):
#   exclude_*: acquisitions whose id contains the pattern are skipped
#   include_*: only acquisitions whose id contains the pattern are processed
# Set a pattern to None to disable that filter.
# Useful to e.g. only process overview or only STED images.
exclude_pattern_target = "sted"
include_pattern_target = None
exclude_pattern_moving = None
include_pattern_moving = None

# --- channel selection -------------------------------------------------------
# indices of the channels that end up in the fused image
channels_to_include_target = (0, 1)
channels_to_include_moving = (0,)

# alternative channel selection:
# channels_to_include_target = (2,)
# channels_to_include_moving = (1, )

# --- fusion options ----------------------------------------------------------
# value written to fused pixels that are out of bounds
# NOTE: an "unnatural" number like -1 makes it easy to spot empty images after fusion
oob_val = -1

# True: fuse all moving images that overlap with a target image
# False: only transform the moving image with the highest overlap and ignore
#        other moving tiles at the border of the target image
# NOTE: resulting images show weird offsets; this still needs testing/fixing,
#       but may be due to stage calibration issues on the microscope
# NOTE: using only the best fitting image may help with overviews,
#       but STED details will still be fused at the border
fuse_multiple_moving = True

# --- output ------------------------------------------------------------------
# subdirectory (of the target base path) to save results to
out_subdir = "aligned_confocal1"

# whether to save projections, and the folder for them (a subdir of out_subdir)
save_projections = True
projections_subdir = "vis"

## 1) Load Transformations per image id

In [None]:
# read the transformation lists (one list per image id) from the JSON parameter file
with open(alignment_params_file) as fd:
    transformation_parameters = json.load(fd)

# image id -> single combined 4x4 matrix
transforms = {}
for img_id, transform_list in transformation_parameters.items():

    mat = np.eye(4)
    # transforms have shape (name, flat parameters)
    # we go through all except the first two (pixel size, stage coords) in reverse
    # and concatenate through multiplication
    for (_, tr) in transform_list[:1:-1]:
        # NOTE: use explicit `mat = mat @ ...` rather than `mat @= ...`:
        # in-place matrix multiplication raises a TypeError on NumPy < 1.25
        mat = mat @ np.array(tr).reshape(4, 4)

    transforms[img_id] = mat

## 2) Load moving images

We now load all moving images, look up the transformation for each of them, and collect the metadata needed for fusion (origin, center and pixel size).

In [None]:
# Pre-load pixel data, transform and coordinate metadata of every selected moving image.
moving_imgs_data = []

for h5_file_moving in (Path(base_path_moving) / raw_subdir_moving).glob('*.h5'):

    print(f"reading moving images from {h5_file_moving}")

    # open explicitly read-only: the raw data is never modified here
    with File(h5_file_moving, 'r') as fd:
        acquisition_ids = list(fd["experiment"].keys())

    # keep only acquisition ids that contain the include pattern and do not match the exclude pattern
    if include_pattern_moving is not None:
        acquisition_ids = [acquisition_id for acquisition_id in acquisition_ids if re.search(include_pattern_moving, acquisition_id)]
    if exclude_pattern_moving is not None:
        acquisition_ids = [acquisition_id for acquisition_id in acquisition_ids if not re.search(exclude_pattern_moving, acquisition_id)]

    for acquisition_id in tqdm(acquisition_ids):

        # get metadata
        meta = get_scan_field_metadata_h5(h5_file_moving, acquisition_id)

        # load pixels of the selected channels
        with File(h5_file_moving, 'r') as fd:
            dataset = fd[f"experiment/{acquisition_id}/0"]
            imgs = [np.array(dataset[str(i)]).squeeze() for i in channels_to_include_moving]

        # coordinates of the image origin and center, scaled by 1e6
        # (presumably m -> micron, matching the pixel size below -- TODO confirm)
        shape = np.array(imgs[0].shape)
        coord_origin = world_coords_for_pixel_spots([0,0,0], meta)[0] * 1e6
        coord_center = world_coords_for_pixel_spots(shape/2, meta)[0] * 1e6

        # find matching transform: construct filename stem plus first level of acquisition id
        # here, we check if acquisition_id is new style (field_X_sted_Y) or old (fieldX_stedY)
        # ids without any underscore are used as they are (previously an IndexError)
        id_parts = acquisition_id.split("_")
        if len(id_parts) > 1 and id_parts[1].isnumeric():
            first_acquisition_id = "_".join(id_parts[:2])
        else:
            first_acquisition_id = id_parts[0]
        img_id = h5_file_moving.stem + f"_{first_acquisition_id}"
        if img_id not in transforms:
            # same exception type as the plain dict lookup, but with a usable message
            raise KeyError(f"no transformation found for image id '{img_id}' (acquisition '{acquisition_id}' in {h5_file_moving})")
        transform = transforms[img_id]

        # save images, (world coord) transform, origin, center and pixel size
        img_data = MovingImageData(imgs, transform, coord_origin, coord_center, meta.pixel_size * 1e6)
        moving_imgs_data.append(img_data)

## 3) Image Fusion

Now, we will go through all target images, and for each of them select overlapping moving images and transform and fuse them into an image of the same size as the target image. Results will be saved as multichannel TIFFs and optionally as PNG RGB orthogonal projections.

In [None]:
# For every selected target acquisition: find the overlapping moving images, resample them
# into the target's pixel grid, and save target + fused channels as a multichannel TIFF
# (optionally also as an RGB PNG of orthogonal projections).
for h5_file_target in (Path(base_path_target) / raw_subdir_target).glob('*.h5'):

    print(f"aligning to dataset {h5_file_target}")

    # open explicitly read-only: the raw data is never modified here
    with File(h5_file_target, 'r') as fd:
        acquisition_ids = list(fd["experiment"].keys())

    # keep only acquisition ids that contain the include pattern and do not match the exclude pattern
    if include_pattern_target is not None:
        acquisition_ids = [acquisition_id for acquisition_id in acquisition_ids if re.search(include_pattern_target, acquisition_id)]
    if exclude_pattern_target is not None:
        acquisition_ids = [acquisition_id for acquisition_id in acquisition_ids if not re.search(exclude_pattern_target, acquisition_id)]

    for acquisition_id in tqdm(acquisition_ids):

        # get metadata
        meta = get_scan_field_metadata_h5(h5_file_target, acquisition_id)

        # load pixels of the selected channels
        with File(h5_file_target, 'r') as fd:
            dataset = fd[f"experiment/{acquisition_id}/0"]
            imgs = [np.array(dataset[str(i)]).squeeze() for i in channels_to_include_target]

        shape = np.array(imgs[0].shape)
        coord_origin = world_coords_for_pixel_spots([0,0,0], meta)[0] * 1e6
        coord_center = world_coords_for_pixel_spots(shape/2, meta)[0] * 1e6

        # find matching transform: construct filename stem plus first level of acquisition id
        # here, we check if acquisition_id is new style (field_X_sted_Y) or old (fieldX_stedY)
        # ids without any underscore are used as they are (previously an IndexError)
        id_parts = acquisition_id.split("_")
        if len(id_parts) > 1 and id_parts[1].isnumeric():
            first_acquisition_id = "_".join(id_parts[:2])
        else:
            first_acquisition_id = id_parts[0]
        img_id = h5_file_target.stem + f"_{first_acquisition_id}"
        if img_id not in transforms:
            # same exception type as the plain dict lookup, but with a usable message
            raise KeyError(f"no transformation found for image id '{img_id}' (acquisition '{acquisition_id}' in {h5_file_target})")
        transform_target = transforms[img_id]

        # values that do not change per moving image: compute once
        transform_target_inv = np.linalg.inv(transform_target)
        pixel_size_target = meta.pixel_size * 1e6

        # go through all moving images, check which ones will overlap with reference after transform
        imgs_to_fuse = []
        transforms_to_fuse = []

        max_overlap = 0
        for moving_img_data in moving_imgs_data:

            # combine the moving image's transform with the inverse of the target's transform
            combined_tr = transform_target_inv @ moving_img_data.transform

            # get transform in pixel units
            transform_i_moving = world_transform_to_pixel_transform(combined_tr, coord_origin, moving_img_data.coord_origin, pixel_size_target, moving_img_data.pixel_size)
            # check overlap (of axis-aligned transformed image)
            mins, maxs = get_axes_aligned_overlap(imgs[0].shape, moving_img_data.imgs[0].shape, None, transform_i_moving)
            # there is some overlap
            if all(mins < maxs):
                if fuse_multiple_moving:
                    imgs_to_fuse.append(moving_img_data.imgs)
                    transforms_to_fuse.append(transform_i_moving)
                else:
                    # we only want to fuse a single moving image -> keep only the one
                    # with the largest overlap seen so far
                    overlap = np.prod(maxs - mins)
                    if overlap > max_overlap:
                        max_overlap = overlap
                        imgs_to_fuse = [moving_img_data.imgs]
                        transforms_to_fuse = [transform_i_moving]

        # fuse in target image bounds
        bbox = [(0,s) for s in imgs[0].shape]

        fused_imgs = []
        for i in range(len(channels_to_include_moving)):
            # channel i of every moving image selected above
            input_imgs = [imgs_i[i] for imgs_i in imgs_to_fuse]
            if len(imgs_to_fuse) > 0:
                fused_img = fuse_image(input_imgs, transforms_to_fuse, bbox=bbox, interpolation_mode='linear', dtype=np.float32, oob_val=oob_val)
            else:
                # no overlapping moving image: fill the channel with the out-of-bounds value
                fused_img = np.full(imgs[0].shape, oob_val, dtype=np.float32)
            fused_imgs.append(fused_img)

        # make multi-channel float32 stack (target channels first, then fused moving channels)
        result = np.array([img.astype(np.float32) for img in imgs] + fused_imgs)

        # save as multichannel TIFF, make output folder (including missing parents) if necessary
        out_file = Path(base_path_target) / out_subdir / (h5_file_target.stem + f"_{acquisition_id}" + '_aligned.tif')
        out_file.parent.mkdir(parents=True, exist_ok=True)
        save_tiff_imagej(out_file, result, axes='czyx', pixel_size=meta.pixel_size*1e6, distance_unit='micron')


        if save_projections:
            # get orthogonal projections for all channels
            projections = [get_orthogonal_projections_8bit(img, meta.pixel_size) for img in result]
            # make RGB composite
            # NOTE(review): three color names are hard-coded, matching the default of
            # 2 target + 1 moving channel -- confirm behavior for other channel counts
            rgb_projection = (gray_images_to_rgb_composite(projections,color_names=['magenta','yellow','cyan']) * 255).astype(np.uint8)

            # get filename, make folder if necessary, save as PNG
            out_file_projections = Path(base_path_target) / out_subdir / projections_subdir / (h5_file_target.stem + f"_{acquisition_id}" + '_aligned_projections.png')
            out_file_projections.parent.mkdir(parents=True, exist_ok=True)
            imsave(out_file_projections, rgb_projection)
