In [7]:
from pathlib import Path
import json
from collections import namedtuple
import re

from tqdm import tqdm
import numpy as np
from scipy.spatial import KDTree
from skimage.io import imsave

from msr_reader import OBFFile
from calmutils.stitching import get_axes_aligned_overlap
from calmutils.stitching.fusion import fuse_image
from calmutils.imageio import save_tiff_imagej
from calmutils.color import gray_images_to_rgb_composite
from calmutils.misc.visualization import get_orthogonal_projections_8bit
from utils.transform_helpers import get_scan_field_metadata, world_coords_for_pixel_spots
from utils.transform_helpers import world_transform_to_pixel_transform

# Pre-loaded data for one moving image:
#   imgs         -> list of image stacks, one per entry of channels_to_include_moving
#   transform    -> transform (world coordinates) taken from the closest transformation-field entry
#   coord_origin -> world coordinates of pixel (0,0,0), in microns
#   coord_center -> world coordinates of the image center, in microns
#   pixel_size   -> pixel size in microns
MovingImageData = namedtuple('MovingImageData', 'imgs transform coord_origin coord_center pixel_size')


In [8]:
# --- input locations ---
# root folders of the target and moving acquisitions
# TODO: absolute, user-specific paths; consider making these relative to a configurable data dir
base_path_target = "/home/stumberger/ep2024/RNA_DNA_FISH_spot_detection/example/DNAFISH/"
base_path_moving = "/home/stumberger/ep2024/RNA_DNA_FISH_spot_detection/example/RNAFISH/"

# folder (relative to each base path) that holds the .msr files
msr_subdir_target = msr_subdir_moving = 'raw'

# transformation field (JSON), relative to base_path_moving
alignment_params_file_moving = 'alignment_parameters/alignment_parameters_global.json'

# --- filename filters (regular expressions, None = filter disabled) ---
# exclude: skip files whose name matches the pattern
# include: only process files whose name matches the pattern
# useful e.g. to restrict processing to overview or STED images
file_exclude_pattern_target = 'sted'
file_include_pattern_target = None
file_exclude_pattern_moving = file_include_pattern_moving = None

# --- channels (stack indices in the .msr files) to put into the fused result ---
channels_to_include_target = (2, )
channels_to_include_moving = (1, )

# value used for fused-image pixels not covered by any moving image
# NOTE: an "unnatural" value such as -1 makes empty regions easy to spot after fusion
oob_val = -1

# fuse all overlapping moving images (True) or only the one with the largest overlap (False);
# with False, other moving tiles touching the border of the target image are ignored
# NOTE: resulting images show weird offsets, still needs testing/fixing, but may be due to stage calibration issues on microscope
# NOTE: using only the best fitting image may help with overviews, but STED details will still be fused at the border
fuse_multiple_moving = True

# --- output ---
# subdirectory (of base_path_target) to save results to
out_subdir = 'aligned_beads'

# whether to additionally save PNG orthogonal projections, and the folder for them (subdir of out_subdir)
save_projections = True
projections_subdir = 'projection_visualization'


## 1) Load transformation field

In [9]:
# The transformation field is a list of entries, each with a center point (in moving-image
# world coordinates) and a flat list of transformation matrix parameters.
alignment_params_path = Path(base_path_moving) / alignment_params_file_moving
with open(alignment_params_path) as fd:
    parameters = json.load(fd)

center_coords_moving = []
transformations_moving = []
for field_entry in parameters['transformations']:
    center_coords_moving.append(np.array(field_entry['center_coords']))
    # flat parameter list -> 4x4 matrix (presumably homogeneous coordinates for 3D)
    transformations_moving.append(np.array(field_entry['parameters']).reshape((4, 4)))

# kd-tree over the transform centers for quick nearest-center lookup;
# tree indices line up with transformations_moving
kd_transform_centers = KDTree(center_coords_moving)

## 2) Load moving images

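We now load all moving images. For each one, we look up the transformation from the field entry whose center is closest to the image center, and record some metadata (origin and center world coordinates, pixel size).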

In [10]:
# Load all moving images that pass the filename filters, together with the metadata and the
# nearest transformation-field entry needed to map them into target image space later.
# NOTE: files are sorted so that the order of moving_imgs_data does not depend on the
# filesystem. That order determines fusion order and the tie-break when only the
# best-overlapping moving image is fused.
moving_imgs_data = []
for msr_file_moving in tqdm(sorted((Path(base_path_moving) / msr_subdir_moving).glob('*.msr'))):

    # skip if the filename doesn't match the include pattern or does match the exclude pattern
    if file_include_pattern_moving is not None and not re.search(file_include_pattern_moving, msr_file_moving.name):
        continue
    if file_exclude_pattern_moving is not None and re.search(file_exclude_pattern_moving, msr_file_moving.name):
        continue

    with OBFFile(msr_file_moving) as reader:
        imgs = [reader.read_stack(i) for i in channels_to_include_moving]

    # metadata of the first included channel is used for all channels
    meta = get_scan_field_metadata(msr_file_moving, channels_to_include_moving[0])
    shape = np.array(imgs[0].shape)
    # world coordinates of pixel (0,0,0) and of the image center, scaled by 1e6 to microns
    coord_origin = world_coords_for_pixel_spots([0, 0, 0], meta)[0] * 1e6
    coord_center = world_coords_for_pixel_spots(shape / 2, meta)[0] * 1e6

    # use the transform of the field entry whose center is closest to this image's center
    _, closest_transform_idx = kd_transform_centers.query(coord_center)
    transform = transformations_moving[closest_transform_idx]

    # save images, (world coord) transform, origin, center and pixel size (microns)
    img_data = MovingImageData(imgs, transform, coord_origin, coord_center, meta.pixel_size * 1e6)
    moving_imgs_data.append(img_data)


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 30.29it/s]


## 3) Image Fusion

Now we go through all target images. For each one, we select the overlapping moving images, transform them, and fuse them into an image of the same size as the target image. Results are saved as multichannel TIFFs and, optionally, as RGB PNG orthogonal projections.

In [17]:
def _select_moving_images(target_shape, target_origin, target_pixel_size, moving_data, fuse_multiple=True):
    """Select the moving images that overlap the target image after transformation.

    Parameters
    ----------
    target_shape : shape (pixels) of the target image
    target_origin : world coordinates (microns) of the target image's pixel (0,0,0)
    target_pixel_size : target pixel size (microns)
    moving_data : iterable of MovingImageData
    fuse_multiple : if True, keep every overlapping moving image; if False, keep only
        the one with the largest axis-aligned overlap volume (first one wins on ties)

    Returns
    -------
    (imgs_to_fuse, transforms_to_fuse) : parallel lists. Each entry of imgs_to_fuse is the
        per-channel image list of one moving image; each transform is in pixel units.
    """
    imgs_to_fuse = []
    transforms_to_fuse = []
    max_overlap = 0
    for moving_img_data in moving_data:
        # world-coordinate transform -> pixel-unit transform between moving and target images
        transform_i_moving = world_transform_to_pixel_transform(
            moving_img_data.transform, target_origin, moving_img_data.coord_origin,
            target_pixel_size, moving_img_data.pixel_size)
        # overlap of the axis-aligned bounding box of the transformed moving image with the target
        mins, maxs = get_axes_aligned_overlap(target_shape, moving_img_data.imgs[0].shape, None, transform_i_moving)
        if not np.all(mins < maxs):
            # no overlap
            continue

        if fuse_multiple:
            imgs_to_fuse.append(moving_img_data.imgs)
            transforms_to_fuse.append(transform_i_moving)
        else:
            # only a single moving image is fused -> keep the one with the largest overlap so far
            overlap = np.prod(maxs - mins)
            if overlap > max_overlap:
                max_overlap = overlap
                imgs_to_fuse = [moving_img_data.imgs]
                transforms_to_fuse = [transform_i_moving]
    return imgs_to_fuse, transforms_to_fuse


def _fuse_moving_channels(target_shape, imgs_to_fuse, transforms_to_fuse, n_channels, oob_value):
    """Fuse each moving channel into a float32 image covering the target image bounds.

    Returns a list of n_channels float32 images of shape target_shape. If there is nothing
    to fuse, each image is filled entirely with oob_value.
    """
    # fuse within the target image bounds
    bbox = [(0, s) for s in target_shape]
    fused_imgs = []
    for channel_idx in range(n_channels):
        if imgs_to_fuse:
            channel_imgs = [moving_channels[channel_idx] for moving_channels in imgs_to_fuse]
            fused_img = fuse_image(bbox, channel_imgs, transforms_to_fuse, interpolation_mode='linear',
                                   dtype=np.float32, oob_val=oob_value)
        else:
            fused_img = np.full(target_shape, oob_value, dtype=np.float32)
        fused_imgs.append(fused_img)
    return fused_imgs


# Process the target files in sorted order so runs are reproducible independent of filesystem order.
out_dir = Path(base_path_target) / out_subdir
for msr_file_target in tqdm(sorted((Path(base_path_target) / msr_subdir_target).glob('*.msr'))):

    # skip if the filename doesn't match the include pattern or does match the exclude pattern
    if file_include_pattern_target is not None and not re.search(file_include_pattern_target, msr_file_target.name):
        continue
    if file_exclude_pattern_target is not None and re.search(file_exclude_pattern_target, msr_file_target.name):
        continue

    # read channels to include in result
    with OBFFile(msr_file_target) as reader:
        imgs = [reader.read_stack(i) for i in channels_to_include_target]

    # metadata of the first included channel is used for all channels; origin scaled to microns
    meta = get_scan_field_metadata(msr_file_target, channels_to_include_target[0])
    coord_origin = world_coords_for_pixel_spots([0, 0, 0], meta)[0] * 1e6

    # pick the moving images that overlap this target image and fuse them into its pixel grid
    imgs_to_fuse, transforms_to_fuse = _select_moving_images(
        imgs[0].shape, coord_origin, meta.pixel_size * 1e6, moving_imgs_data, fuse_multiple_moving)
    fused_imgs = _fuse_moving_channels(
        imgs[0].shape, imgs_to_fuse, transforms_to_fuse, len(channels_to_include_moving), oob_val)

    # multi-channel float32 stack: target channels first, then fused moving channels
    result = np.array([img.astype(np.float32) for img in imgs] + fused_imgs)

    # save as multichannel TIFF, creating the output folder if necessary
    out_file = out_dir / (msr_file_target.stem + '_aligned.tif')
    out_file.parent.mkdir(parents=True, exist_ok=True)
    save_tiff_imagej(out_file, result, axes='czyx', pixel_size=meta.pixel_size*1e6, distance_unit='micron')

    if save_projections:
        # orthogonal projections for all channels, combined into an RGB composite
        # NOTE(review): unlike the TIFF export, the raw (unscaled) meta.pixel_size is passed here;
        # presumably only the relative pixel sizes matter for the projections -- confirm
        projections = [get_orthogonal_projections_8bit(img, meta.pixel_size) for img in result]
        rgb_projection = (gray_images_to_rgb_composite(projections) * 255).astype(np.uint8)

        out_file_projections = out_dir / projections_subdir / (msr_file_target.stem + '_aligned_projections.png')
        out_file_projections.parent.mkdir(parents=True, exist_ok=True)
        imsave(out_file_projections, rgb_projection)



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:04<00:00,  4.96it/s]
