In [5]:
from pathlib import Path
import re
from functools import reduce
from operator import or_, add
import json

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from skimage.measure import ransac
from skimage.transform import AffineTransform, SimilarityTransform, EuclideanTransform

from calmutils.descriptors import descriptor_local_qr, match_descriptors_kd
from calmutils.stitching.registration import register_iterative
from calmutils.stitching.transform_helpers import translation_matrix

In [6]:
def combine_dicts_along_keys(*dicts, combine_function=add, error_missing_key=False):

    """
    Combine the values for each key in multiple dicts via reduction with a user-specified function.

    By default, all values present for a key are combined and dicts that lack
    the key are skipped. With error_missing_key=True, a ValueError is raised
    as soon as a key is found that is not present in every dict.

    Parameters
    ----------
    *dicts: the dicts to combine; values for the same key are reduced in the
        order in which the dicts are given.
    combine_function: binary function used for the reduction (default: operator.add).
    error_missing_key: whether to raise ValueError for keys missing from any dict.

    Returns
    -------
    A new dict with one combined value per key. Keys are in order of first
    appearance across the input dicts. Calling without any dicts returns {}.
    """

    # TODO: general-purpose function, move to CalmUtils?

    # collect all keys present in any dict; a dict is used as an ordered set so
    # the key order is deterministic and an empty argument list is handled
    all_keys = {}
    for d in dicts:
        all_keys.update(dict.fromkeys(d))

    combined_dicts = {}

    for k in all_keys:
        # get present values for key
        present_values = [d[k] for d in dicts if k in d]

        # number of present values does not match number of dicts -> raise error if desired
        if error_missing_key and (len(present_values) != len(dicts)):
            raise ValueError(f"key '{k}' not found in all dicts.")

        # combine via reduction with combine_function
        # (a single present value is returned as-is, without calling combine_function)
        combined_dicts[k] = reduce(combine_function, present_values)

    return combined_dicts

In [21]:
# --- inputs: bead detection tables of the two datasets -----------------------
# (dataset 1 is the one that gets moved onto dataset 2 in the global alignment)
detection_table1_path = '/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250823_DNAFISH//detections_beads/merge_global_coords.csv'
detection_table2_path = '/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250821_RNAFISH//detections_beads/merge_global_coords.csv'

# --- outputs ------------------------------------------------------------------
# transform JSON after the global step (_g) and after local round 1 (_l)
json_save_path_g = '/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250823_DNAFISH/transformations_global.json'
json_save_path_l = '/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250823_DNAFISH//transformations_local_1round.json'
# per-pair residual table written after local alignment
summary_df_save_path = "/data/agl_data/NanoFISH/Gabi/GS813_Nanog_RNA-DNA_all/20250823_DNAFISH/alignment_accuracy1.csv"

# --- column names in the detection tables -------------------------------------
# world coordinates (z is kept separate because it is also used on its own)
coordinate_column_z = "z_global_um"
coordinate_columns_yx = ["y_global_um", "x_global_um"]
# pixel coordinates, same axis order as the world coordinates
coordinate_columns_pixel = ["z", "y", "x"]
# column holding the image file and name of the derived id column
image_file_column, IMAGE_ID_COLUMN = "img", "image_id"

# --- coverslip detection --------------------------------------------------------
# the z position of a low quantile of the detections is taken as coverslip plane,
# as this should correspond to beads sitting on the coverslip
z_bottom_quantile = 0.1

# --- descriptor matching ----------------------------------------------------------
# neighbors / redundancy handed to the descriptor function
n_neighbors, redundancy = 4, 0
# matching is called with max_ratio = 1 / descriptor_match_ratio
descriptor_match_ratio = 2

# --- RANSAC (global alignment) ------------------------------------------------------
# residual threshold and number of trials
ransac_max_error, ransac_max_trials = 4.0, 100_000


In [8]:
# full list of world-coordinate columns: z first, followed by y and x
coordinate_columns = [coordinate_column_z, *coordinate_columns_yx]

# read the detection tables of both datasets
df1, df2 = (pd.read_csv(table_path) for table_path in (detection_table1_path, detection_table2_path))

# number of distinct image files in each table
tuple(len(table[image_file_column].unique()) for table in (df1, df2))

(388, 544)

In [9]:
def get_file_stem_without_channel(path):
    """
    Get the stem of a file path (file name without directory and last extension)
    with a trailing channel suffix ``_ch{ch_id}`` removed.

    Only a suffix at the very end of the stem is removed, e.g.
    ``/a/img_001_ch2.tif`` -> ``img_001``. A ``_ch{ch_id}`` in the middle of
    the name (``sample_ch1_tile_3.tif``) is left untouched; stems without a
    channel suffix are returned unchanged.
    """

    stem = Path(path).stem

    # anchor at the end of the stem (\Z): an unanchored match would also fire on
    # names that merely contain "_ch<digits>" and then cut off an unrelated token
    return re.sub(r"_ch[0-9]+\Z", "", stem)

# derive a cleaner per-image id from the file column of both tables
# (these ids are also the keys of the saved transforms)
for detections in (df1, df2):
    detections[IMAGE_ID_COLUMN] = detections[image_file_column].map(get_file_stem_without_channel)

## 0) Get pixel size and stage position transforms for all images.

This could also be done in the *get global coordinates* notebook, but here we obtain them as explicit transformation matrices.

In [10]:
import warnings

# per-image transforms, keyed by image id:
# pixel_size_transforms hold the diagonal (scale) part of the pixel -> world fit,
# stage_position_transforms hold its translation part
pixel_size_transforms = {}
stage_position_transforms = {}

# go over both dfs
# NOTE: both tables are grouped together, so image ids are assumed to be unique across datasets
for img_id, dfi in pd.concat([df1, df2]).groupby(IMAGE_ID_COLUMN):

    # get pixel and world coords
    pixel_coords = dfi[coordinate_columns_pixel].values
    world_coords = dfi[coordinate_columns].values

    # estimate affine transform (pixel size scale + stage translation)
    at = AffineTransform(dimensionality=3)
    estimation_ok = at.estimate(pixel_coords, world_coords)

    # do not silently store an unusable fit (e.g. too few / degenerate detections in an image)
    if not estimation_ok or not np.all(np.isfinite(at.params)):
        warnings.warn(f"pixel -> world transform estimation failed for image '{img_id}' ({len(dfi)} detections).")

    # 2x diag (inner extracts diag, outer makes new diag matrix with only those entries)
    mat_pixelsize = np.diag(np.diag(at.params))

    # last column without the homogeneous row = translation vector
    mat_stage_translation = translation_matrix(at.params[:-1, -1])

    pixel_size_transforms[img_id] = AffineTransform(mat_pixelsize)
    stage_position_transforms[img_id] = AffineTransform(mat_stage_translation)

## 1) Get coverslip position, get shift to align

First, we estimate the coverslip position $z_{cs}$ in every image by getting a low quantile of all detections in that image.
A transformation that virtually aligns the coverslip positions is the translation $(-z_{cs}, 0, 0)$.

In [11]:
# image id -> translation that moves the estimated coverslip plane of that image to z = 0
z_transforms = {}

# same procedure for both datasets (dataset 1 first, then dataset 2)
for detections in (df1, df2):
    for image_id, dfi in detections.groupby(IMAGE_ID_COLUMN):
        # coverslip position: low quantile of the z coordinates in this image
        z_coverslip = np.quantile(dfi[coordinate_column_z], z_bottom_quantile)
        shift_zyx = [-z_coverslip, 0, 0]
        z_transforms[image_id] = AffineTransform(translation_matrix(shift_zyx))

## 2) Global Alignment of two datasets via beads

In [12]:
def get_transformed_coordinates(df, transforms=None, coordinate_columns=coordinate_columns, key=IMAGE_ID_COLUMN):
    """
    Get the coordinates of all rows in df, optionally mapped through per-image transforms.

    Parameters
    ----------
    df: table with one point per row.
    transforms: optional dict mapping the values of column `key` to callables
        that take and return an array of coordinates (one point per row).
        If None, the raw coordinate values are returned.
    coordinate_columns: columns of df holding the coordinates.
    key: column of df used to look up the transform for each row.

    Returns
    -------
    Array with one row per row of df, in the same row order as df, so the
    result can be assigned back to df or indexed with positions into df.
    Rows that pandas does not assign to any group (missing key) stay NaN.

    Raises
    ------
    KeyError if a value of `key` has no entry in transforms.
    """

    coords = df[coordinate_columns].values

    # no transform to apply -> just return values of coordinate columns
    if transforms is None:
        return coords

    coords_tr = np.full(coords.shape, np.nan, dtype=float)

    # go through all image_ids and apply corresponding transform from transform dict
    # NOTE: results are written back at the row positions of each group, so the
    # original row order is kept even if the rows of an image are not contiguous
    for image_id, row_positions in df.groupby(key, sort=False).indices.items():
        coords_tr[row_positions] = transforms[image_id](coords[row_positions])

    return coords_tr

In [13]:
# coordinates of both datasets after coverslip alignment
coords1, coords2 = (get_transformed_coordinates(table, z_transforms) for table in (df1, df2))

# descriptors (and the indices of the points they belong to) for both point sets
(desc1, idx1), (desc2, idx2) = (
    descriptor_local_qr(point_set, n_neighbors, redundancy) for point_set in (coords1, coords2)
)

# candidate correspondences between descriptors of dataset 1 and dataset 2
max_match_ratio = 1 / descriptor_match_ratio
matches = match_descriptors_kd(desc1, desc2, max_ratio=max_match_ratio)

len(matches)

4505

In [14]:
# map descriptor matches back to indices into the coordinate arrays
matched_idxs1 = idx1[matches[:, 0]]
matched_idxs2 = idx2[matches[:, 1]]

matched_coords1 = coords1[matched_idxs1]
matched_coords2 = coords2[matched_idxs2]

# robustly fit one rigid transform dataset1 -> dataset2 to the candidate matches
# (4 = minimal sample size per trial, ransac_max_error = residual threshold for inliers)
transform_global, inliers_global = ransac((matched_coords1, matched_coords2), EuclideanTransform, 4, ransac_max_error, max_trials=ransac_max_trials)

# RANSAC yields None when no consensus set is found; fail with a clear message here
# (same guard as in the local alignment step) instead of an obscure error further down
if transform_global is None or inliers_global is None:
    raise RuntimeError(
        f"Global RANSAC found no consensus transform for {len(matched_coords1)} descriptor matches; "
        "check the descriptor / RANSAC parameters."
    )

# distance of transformed inlier points of dataset1 to their matches in dataset2
residuals = np.linalg.norm(transform_global(matched_coords1[inliers_global]) - matched_coords2[inliers_global], axis=1)

print(f"RANSAC inliers: {inliers_global.sum()} / {len(matched_coords1)}")
print(f"Residual error (mean, max): {residuals.mean() :.3f}, {residuals.max() :.3f}")

# use the global transform for each image in dataset1 (moving)
transforms_global = {image_id: transform_global for image_id in df1[IMAGE_ID_COLUMN].unique()}
# for dataset2 (target), append identity transform to have same number of transforms
transforms_global |= {image_id: AffineTransform(dimensionality=3) for image_id in df2[IMAGE_ID_COLUMN].unique()}

RANSAC inliers: 3192 / 4505
Residual error (mean, max): 2.159, 4.410


### Save

In [15]:
# (name, per-image transform dict) for every step, in the order they are stored per image
transform_steps = (
    ("pixel_size", pixel_size_transforms),
    ("stage_position", stage_position_transforms),
    ("coverslip_align", z_transforms),
    ("global_registration", transforms_global),
)

# image id -> list of (step name, flattened transform matrix as plain list)
dict_to_save = {
    img_id: [
        (step_name, step_transforms[img_id].params.ravel().tolist())
        for step_name, step_transforms in transform_steps
    ]
    for img_id in pd.concat([df1, df2])[IMAGE_ID_COLUMN].unique()
}

with open(json_save_path_g, "w") as fd:
    json.dump(dict_to_save, fd, indent=1)

In [None]:
# from matplotlib import pyplot as plt
# # import napari

# tr_combined = combine_dicts_along_keys(z_transforms, transforms_global)

# coords_tr1 = get_transformed_coordinates(df1, tr_combined)
# coords_tr2 = get_transformed_coordinates(df2, tr_combined)

# plt.scatter(*coords_tr1.T[1:], s=0.002, alpha=0.8)
# plt.scatter(*coords_tr2.T[1:], s=0.002, alpha=0.8)

# if napari.current_viewer() is not None:
#     napari.current_viewer().close()

# viewer = napari.Viewer()
# viewer.add_points(coords_tr1[:, 0:], face_color='cyan', border_color="#FFF0", size=3)
# viewer.add_points(coords_tr2[:, 0:], face_color='magenta', border_color="#FFF0", size=3)

## 3) Local Alignment

Now, we repeat the alignment as in step 2 on a per-image basis:

- We consider a pair of images (one from each dataset) as overlapping if their mean transformed coordinates differ by less than a threshold (~FOV size).
- For each overlapping pair, we perform descriptor matching and RANSAC.
- The inlier point matches of all pairs with enough matches are used to calculate globally optimal consensus transforms (like in Multiview Reconstruction).

In [16]:
# pairs of images count as overlapping candidates if their mean coordinates
# are closer than this distance
overlap_mean_distance_cutoff = 50

# a pair of images is only used if at least this many point matches are found
# and still remain as inliers after RANSAC
min_matches_local = 12

# redundancy handed to the descriptor function in the local step
redundancy_local = 1

# RANSAC residual threshold in the local step
max_error_local = 2.0

In [17]:
from scipy.spatial import distance_matrix

# combine transforms so far
tr_combined = combine_dicts_along_keys(z_transforms, transforms_global)

# names of the columns that receive the transformed coordinates
tr_columns = [f"{c}_tr" for c in coordinate_columns]

# pre-apply to a copy of datasets
df1_with_tr = df1.copy()
df1_with_tr[tr_columns] = get_transformed_coordinates(df1_with_tr, tr_combined)
df2_with_tr = df2.copy()
df2_with_tr[tr_columns] = get_transformed_coordinates(df2_with_tr, tr_combined)

# get mean coords per image
mean_coords_df_1 = df1_with_tr.groupby(IMAGE_ID_COLUMN)[tr_columns].mean()
mean_coords_df_2 = df2_with_tr.groupby(IMAGE_ID_COLUMN)[tr_columns].mean()
mean_coords_1 = mean_coords_df_1.values
mean_coords_2 = mean_coords_df_2.values

# take the image ids from the index of the mean tables, so that row i of
# mean_coords_* is guaranteed to belong to img_ids_*[i]
img_ids_1 = mean_coords_df_1.index.to_numpy()
img_ids_2 = mean_coords_df_2.index.to_numpy()

# select only pairs with difference of mean coords less than cutoff
mean_distances = distance_matrix(mean_coords_1, mean_coords_2)
overlapping_fields = [(img_ids_1[i], img_ids_2[j]) for i, j in np.argwhere(mean_distances < overlap_mean_distance_cutoff)]

In [18]:
# (image id in dataset 1, image id in dataset 2) -> (inlier coords 1, inlier coords 2)
to_optimize_round1 = {}

# look up the pre-transformed coordinates of every image once instead of
# filtering the full tables again for each pair
tr_columns_local = [f"{c}_tr" for c in coordinate_columns]
coords_tr_by_image_1 = {image_id: dfi[tr_columns_local].values for image_id, dfi in df1_with_tr.groupby(IMAGE_ID_COLUMN)}
coords_tr_by_image_2 = {image_id: dfi[tr_columns_local].values for image_id, dfi in df2_with_tr.groupby(IMAGE_ID_COLUMN)}

for img_id_1, img_id_2 in tqdm(overlapping_fields):

    coords_tr_1 = coords_tr_by_image_1[img_id_1]
    coords_tr_2 = coords_tr_by_image_2[img_id_2]

    # NOTE: pair-specific names are used here so that the descriptors / matches
    # of the global alignment step (desc1, idx1, desc2, idx2, matches) are not overwritten
    desc_pair_1, idx_pair_1 = descriptor_local_qr(coords_tr_1, n_neighbors, redundancy_local, scale_invariant=True)
    desc_pair_2, idx_pair_2 = descriptor_local_qr(coords_tr_2, n_neighbors, redundancy_local, scale_invariant=True)
    matches_pair = match_descriptors_kd(desc_pair_1, desc_pair_2, max_ratio=1/2.0)

    # too few candidate matches -> skip pair
    if len(matches_pair) < min_matches_local:
        continue

    coords_match_1 = coords_tr_1[idx_pair_1[matches_pair[:, 0]]]
    coords_match_2 = coords_tr_2[idx_pair_2[matches_pair[:, 1]]]

    # only the inlier mask is needed, the per-pair model itself is discarded
    _, inliers = ransac((coords_match_1, coords_match_2), EuclideanTransform, 3, max_error_local, max_trials=1000)

    # no consensus or too few inliers -> skip pair
    if inliers is None or (inliers.sum() < min_matches_local):
        continue

    coords_inliers_1 = coords_match_1[inliers]
    coords_inliers_2 = coords_match_2[inliers]
    to_optimize_round1[(img_id_1, img_id_2)] = (coords_inliers_1, coords_inliers_2)

  0%|          | 0/879 [00:00<?, ?it/s]

In [19]:
# consensus transforms from the inlier matches of all image pairs
refine_transforms_round1 = register_iterative(to_optimize_round1, transform_type=EuclideanTransform, max_iterations=500)
print(f"Per-tile transforms estimated for {len(refine_transforms_round1)} images")

# images without an estimated transform get an identity transform instead
all_img_ids = pd.concat([df1, df2])[IMAGE_ID_COLUMN].unique()
img_ids_without_transform = [img_id for img_id in all_img_ids if img_id not in refine_transforms_round1]
for img_id in img_ids_without_transform:
    refine_transforms_round1[img_id] = AffineTransform(dimensionality=3)

Per-tile transforms estimated for 542 images


### Save / Visualize

In [22]:
# (name, per-image transform dict) for every step, in the order they are stored per image
transform_steps = (
    ("pixel_size", pixel_size_transforms),
    ("stage_position", stage_position_transforms),
    ("coverslip_align", z_transforms),
    ("global_registration", transforms_global),
    ("tile_registration_round1", refine_transforms_round1),
)

# image id -> list of (step name, flattened transform matrix as plain list)
dict_to_save = {
    img_id: [
        (step_name, step_transforms[img_id].params.ravel().tolist())
        for step_name, step_transforms in transform_steps
    ]
    for img_id in pd.concat([df1, df2])[IMAGE_ID_COLUMN].unique()
}

with open(json_save_path_l, "w") as fd:
    json.dump(dict_to_save, fd, indent=1)

### Summary table

Table of coordinate differences of inliers after round 1 local alignment (for all pairs of images considered for round 1)

In [None]:
from collections import defaultdict

# fixed column layout, so the table has its columns even if no image pair was aligned
summary_column_names = ["image_id_1", "image_id_2", "diff_mean", "diff_max", "n_inliers"]
summary_columns = defaultdict(list)

# NOTE: the loop variables are deliberately not called coords1 / coords2, as that
# would overwrite the full coordinate arrays of the global alignment step
for (id1, id2), (inlier_coords_1, inlier_coords_2) in to_optimize_round1.items():
    # remaining offset between matched inlier points after applying the round 1 transforms
    coord_diff = refine_transforms_round1[id1](inlier_coords_1) - refine_transforms_round1[id2](inlier_coords_2)
    inlier_distances = np.linalg.norm(coord_diff, axis=1)

    summary_columns["image_id_1"].append(id1)
    summary_columns["image_id_2"].append(id2)
    summary_columns["diff_mean"].append(inlier_distances.mean())
    summary_columns["diff_max"].append(inlier_distances.max())
    summary_columns["n_inliers"].append(len(coord_diff))

summary_df = pd.DataFrame(dict(summary_columns), columns=summary_column_names)

# one row per image pair, without the row index
summary_df.to_csv(summary_df_save_path, index=False)

summary_df

In [None]:
# per image of dataset 1: mean inlier distance over all its partner images,
# weighted by the number of inlier points of each pair; show the 10 smallest
def _inlier_weighted_mean(pairs_of_image):
    return np.average(pairs_of_image.diff_mean, weights=pairs_of_image.n_inliers)

error_per_image = summary_df.groupby("image_id_1").apply(_inlier_weighted_mean, include_groups=False)
error_per_image.sort_values().head(10)

In [None]:
# from matplotlib import pyplot as plt
# import napari

# tr_combined = combine_dicts_along_keys(z_transforms, transforms_global, refine_transforms_round1)

# coords_tr1 = get_transformed_coordinates(df1, tr_combined)
# coords_tr2 = get_transformed_coordinates(df2, tr_combined)

# plt.scatter(*coords_tr1.T[1:], s=0.002, alpha=0.8)
# plt.scatter(*coords_tr2.T[1:], s=0.002, alpha=0.8)

# if napari.current_viewer() is not None:
#     napari.current_viewer().close()

# viewer = napari.Viewer()
# viewer.add_points(coords_tr1[:, 0:], face_color='cyan', border_color="#FFF0", size=3)
# viewer.add_points(coords_tr2[:, 0:], face_color='magenta', border_color="#FFF0", size=3)