# Estimate Channel Alignment from multiple images

This notebook will pool spot detections from multiple files and use them to estimate alignment parameters between channels.


In [None]:
# --- Input location ---
# base folder (local mount of the data); the commented-out line is an alternative mount point
# in_path = '/run/user/1000/gvfs/sftp:host=10.163.69.11/md/90/agl_data/NanoFISH/Gabi/GS666_tetraspeck_on_cells_1-50'
in_path = '/Volumes/agl_data/NanoFISH/Gabi/GS666_tetraspeck_on_cells_1-50/'
# subfolder of in_path holding the spot detection tables, and the glob pattern selecting them
detection_subdirectory = 'spot-detection'
file_pattern = '*.csv'

# --- Relevant column names in the detection tables ---
filename_column = 'image_file'
channel_column = 'channel'
coordinate_columns = ["z_micron", "y_micron", "x_micron"]

# --- Matching ---
# maximal distance (same unit as coordinate_columns) for LAP matching between channels
matching_max_dist = 1

# --- Transform model ---
# supported values:
#   "similarity": shift, rotation, scale
#   "affine":     additionally includes shearing, which may be undesired
#                 (e.g. it sometimes gave weird results on a single layer of beads)
transform_model_type = "affine"

# --- RANSAC ---
# residual threshold: lower means more stringent inlier filtering,
# but may lead to no transformation being estimated at all
residual_threshold = 0.1

# maximum number of RANSAC trials; more may help when there are very few inliers
ransac_max_trials = 5000

# name of the spatial unit of the coordinates (used in printouts and the output file)
pixel_unit = 'micron'

## 1. Load & combine spot detections

First, we load and concatenate all detection tables:

In [None]:
from pathlib import Path
import pandas as pd

# all detection tables, sorted for a reproducible concatenation order
detection_dir = Path(in_path) / detection_subdirectory
in_files = sorted(detection_dir.glob(file_pattern))

# pd.concat on an empty list only says "No objects to concatenate",
# so fail early with a message pointing at the searched location
if not in_files:
    raise FileNotFoundError(f"no files matching '{file_pattern}' found in {detection_dir}")

# one combined table with a fresh 0..n-1 index
df = pd.concat([pd.read_csv(in_file) for in_file in in_files], ignore_index=True)

df

## 2. Match detections between channels

Next, within each image, we match detections between every pair of channels. We use linear assignment and discard matches whose distance exceeds a maximum (`matching_max_dist`, see parameters).

This works fine for applications like chromatic shift correction but assumes small shifts so we can match purely on distance.

**TODO:** also add descriptor-based matching for larger shifts.

In [None]:
from itertools import combinations
from collections import defaultdict

import numpy as np
from scipy.optimize import linear_sum_assignment

from calmutils.localization.metrics import get_coord_distance_matrix


# cost assigned to pairs further apart than matching_max_dist to discourage matching them
# (any such pairs the solver still picks are removed after the assignment)
unmatchable_cost = matching_max_dist * 9000

matched_df = []

for file_path, dfi in df.groupby(filename_column):

    # groupby sorts its keys, so ch1 < ch2 for every pair yielded by combinations
    for (ch1, dfi_ch1), (ch2, dfi_ch2) in combinations(dfi.groupby(channel_column), 2):

        coords_ch1 = dfi_ch1[coordinate_columns].values
        coords_ch2 = dfi_ch2[coordinate_columns].values

        # distance matrix (rows: ch1 spots, columns: ch2 spots);
        # set distances above max_dist to a very large value to discourage matching
        d = get_coord_distance_matrix(coords_ch1, coords_ch2)
        d[d > matching_max_dist] = unmatchable_cost

        # optimal one-to-one matching
        row_idx, col_idx = linear_sum_assignment(d)

        # keep only pairs within max_dist (same cutoff as used for the penalty above)
        within_dist = d[row_idx, col_idx] <= matching_max_dist
        coords_ch1_matched = coords_ch1[row_idx[within_dist]]
        coords_ch2_matched = coords_ch2[col_idx[within_dist]]

        matched_dfi = {f"{col}_ch1": vals for col, vals in zip(coordinate_columns, coords_ch1_matched.T)}
        matched_dfi |= {f"{col}_ch2": vals for col, vals in zip(coordinate_columns, coords_ch2_matched.T)}
        matched_dfi = pd.DataFrame.from_dict(matched_dfi)
        # channel labels come straight from the groups the coordinates were taken from
        matched_dfi["channel1"] = ch1
        matched_dfi["channel2"] = ch2
        matched_dfi["image_file"] = file_path

        matched_df.append(matched_dfi)


# pd.concat fails unhelpfully on an empty list
if not matched_df:
    raise ValueError("no channel pairs to match: no image contains more than one channel")

# spot_idx: running index of the matched pair within its (image, channel pair) table
matched_df = pd.concat(matched_df).reset_index(names="spot_idx")
matched_df

## 3. Add FOV info to table

For transform estimation, we also need the FOV size of the corresponding images and the direction of the z-axis (bottom-to-top or top-to-bottom). Run only one of the two alternatives below.

### ALTERNATIVE 1: Read from nd2 files


In [None]:
import nd2
from nd2.structures import ZStackLoop
from calmutils.misc.file_utils import get_common_subpath

fov_info_df = defaultdict(list)

for file_path, dfi in df.groupby(filename_column):

    # map the (remote) path stored in the table to the local mount (in_path)
    # NOTE: we assume the paths share at least some common subpath
    _, (prefix_remote, prefix_local), _ = get_common_subpath(file_path, in_path)
    file_path_local = file_path.replace(prefix_remote, prefix_local, 1)

    with nd2.ND2File(file_path_local) as reader:
        # z direction of the first z-stack loop; None if the experiment has no z-stack loop
        bottom_to_top = next(
            (loop.parameters.bottomToTop for loop in reader.experiment if isinstance(loop, ZStackLoop)),
            None,
        )
        # image size in pixels, (z, y, x); 'Z' may be absent for single-plane files
        fov_pixel = np.array([reader.sizes.get(dim, 1) for dim in 'ZYX'], dtype=float)
        # voxel_size() is reversed to match the (z, y, x) order used here
        pixel_sizes = np.array(reader.voxel_size()[::-1], dtype=float)

    fov_micron = fov_pixel * pixel_sizes

    fov_info_df['image_file'].append(file_path)
    # NOTE(review): a None value (no z-stack loop) ends up mirrored in z by the transform step
    fov_info_df['bottom_to_top'].append(bottom_to_top)

    for i, dim in enumerate('zyx'):
        fov_info_df[f'fov_micron_{dim}'].append(fov_micron[i])
        fov_info_df[f'pixel_size_micron_{dim}'].append(pixel_sizes[i])

fov_info_df = pd.DataFrame.from_dict(fov_info_df)

# join with matched df; drop FOV columns from an earlier run of this cell (or of the
# manual alternative) first, otherwise merge() would create _x/_y suffixed duplicates
stale_columns = [col for col in fov_info_df.columns if col != 'image_file']
matched_df = matched_df.drop(columns=stale_columns, errors='ignore').merge(fov_info_df, on='image_file')

### ALTERNATIVE 2: Manually set FOV size, z-direction

In [None]:
# manually specified FOV size and pixel size (micron), both in (z, y, x) order
fov_manual = [10, 133.12, 133.12]
pixel_size_manual = [0.3, 0.13, 0.13]
# z direction of the stacks: True if acquired bottom-to-top
bottom_to_top = True

# the same values are written to every row of the matched table
for axis, fov_value, pixel_size_value in zip('zyx', fov_manual, pixel_size_manual):
    matched_df[f'fov_micron_{axis}'] = fov_value
    matched_df[f'pixel_size_micron_{axis}'] = pixel_size_value
matched_df['bottom_to_top'] = bottom_to_top

## 4. Estimate Transformation

In [None]:
from skimage.transform import AffineTransform, SimilarityTransform
from skimage.measure import ransac


def affine_transform_nd(dimensionality):
    """Return a zero-argument constructor for an AffineTransform of the given dimensionality.

    ransac() instantiates the model class without arguments; AffineTransform
    would default to 2 dimensions otherwise.
    """
    return lambda: AffineTransform(dimensionality=dimensionality)


# constructor for the selected transform model
transform_type = {
    "affine": affine_transform_nd(3),
    "similarity": SimilarityTransform,
}[transform_model_type]

# number of point pairs RANSAC fits each candidate model to
ransac_min_samples = 4

inlier_dfs = {}

transforms = {}

for (ch1, ch2), dfi in matched_df.groupby(['channel1', 'channel2']):

    # matched coordinates of both channels (explicit copies, they are modified below);
    # column names follow the matching step: f"{col}_ch1" / f"{col}_ch2"
    # NOTE: assumes coordinate_columns are ordered (z, y, x), i.e. z is column 0
    matched_coords_ch1 = dfi[[f"{col}_ch1" for col in coordinate_columns]].to_numpy(dtype=float, copy=True)
    matched_coords_ch2 = dfi[[f"{col}_ch2" for col in coordinate_columns]].to_numpy(dtype=float, copy=True)
    fov_z = dfi["fov_micron_z"].to_numpy(dtype=float)
    # rows whose image was not acquired bottom-to-top (False or None)
    flip_z = ~dfi["bottom_to_top"].to_numpy().astype(bool)

    # flip z of top-to-bottom stacks and move back up by the FOV (z' = fov_z - z),
    # so all coordinates follow the bottom-to-top convention; y and x are untouched
    for coords in (matched_coords_ch1, matched_coords_ch2):
        coords[flip_z, 0] = fov_z[flip_z] - coords[flip_z, 0]

    # TODO: correct for unequal FOVs, one possible solution is to center the
    # coordinates of each image by subtracting half its FOV size per axis

    # ransac() cannot fit a model to fewer pairs than min_samples
    n_pairs = len(dfi)
    if n_pairs < ransac_min_samples:
        print(f'RANSAC on {ch1}->{ch2} skipped: only {n_pairs} matched pairs (need >= {ransac_min_samples})')
        continue

    # estimate the transform mapping ch1 coordinates onto ch2 coordinates
    transform, inliers = ransac((matched_coords_ch1, matched_coords_ch2),
                                transform_type, ransac_min_samples,
                                residual_threshold=residual_threshold, max_trials=ransac_max_trials)

    # ransac() returns (None, None) if no model with inliers was found
    if transform is None:
        print(f'RANSAC on {ch1}->{ch2} failed: no transformation found '
              f'(consider a larger residual_threshold or more trials)')
        continue

    print(f'RANSAC on {ch1}->{ch2} inliers: {inliers.sum()}/{len(inliers)}')

    # save copy of matched coordinates with inlier info (for transformation field visualization)
    inlier_df = dfi.copy()
    inlier_df["inlier"] = inliers
    inlier_dfs[(ch1, ch2)] = inlier_df

    # distance details on the inliers, before and after applying the transform
    inliers_ch1 = matched_coords_ch1[inliers]
    inliers_ch2 = matched_coords_ch2[inliers]
    offset_before = inliers_ch1 - inliers_ch2
    offset_after = transform(inliers_ch1) - inliers_ch2

    dist_before_norm = np.linalg.norm(offset_before, axis=1).mean()
    dist_after_norm = np.linalg.norm(offset_after, axis=1).mean()
    dist_before = offset_before.mean(axis=0)
    dist_after = offset_after.mean(axis=0)

    print(f'mean distance before transform: {dist_before_norm:.3f} {pixel_unit}, after: {dist_after_norm:.3f} {pixel_unit}')
    print(f'mean offset per axis before transform: {dist_before} {pixel_unit}, after: {dist_after} {pixel_unit}')

    # put matrix of estimated transform plus inverse into results
    transforms[(ch1, ch2)] = transform.params
    transforms[(ch2, ch1)] = np.linalg.inv(transform.params)

## 5. Save Results as JSON

In [None]:
import json
from pathlib import Path

import numpy as np

out_file = Path(in_path) / 'channel_registration_multifile-3c.json'


def _to_builtin(obj):
    """json.dump fallback: convert numpy scalars (e.g. int64 channel ids) to Python scalars."""
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


# TODO: also save pixel size and field of view?
output = {
    # .tolist() yields plain Python scalars (list(ndarray) would keep numpy int64)
    'channels': df[channel_column].unique().tolist(),
    'size_unit': pixel_unit,
    # coordinates were converted to the bottom-to-top convention before estimation
    'z_direction': "bottom_to_top",
    'source_file': df[filename_column].unique().tolist(),
    # 'channels': (source, target) pair; 'parameters': homogeneous matrix, row-major
    'transforms': [{'channels': list(k), 'parameters': v.ravel().tolist()} for k, v in transforms.items()],
}

with open(out_file, 'w', encoding='utf-8') as fd:
    json.dump(output, fd, indent=1, default=_to_builtin)