# Estimate Channel Alignment from multiple images

The first sections of this notebook match `alignment_estimation_from_coordinate_tables.ipynb` (`subscripts`), but this notebook also performs additional QC and plotting.


In [None]:
# Specify input: base folder, spot detection subfolder, file pattern
# (detection tables are searched in <in_path>/<detection_subdirectory>/<file_pattern>)
# in_path = '/run/user/1000/gvfs/sftp:host=10.163.69.11/md/90/agl_data/NanoFISH/Gabi/GS666_tetraspeck_on_cells_1-50'
in_path = '/Volumes/agl_data/NanoFISH/Gabi/GS666_tetraspeck_on_cells_1-50/'
detection_subdirectory = 'spot-detection'
file_pattern = '*.csv'

# relevant column names in the detection tables:
# source image of a detection, its channel, and its coordinates in (z, y, x) order
filename_column = 'image_file'
channel_column = 'channel'
coordinate_columns = ["z_micron", "y_micron", "x_micron"]

# maximal distance for LAP matching between channels
# (compared against distances computed from the coordinate columns, so same unit)
matching_max_dist = 1

# we support "similarity": shift, rotate, scale
# alternative: "affine" -> includes shearing, which may be undesired
# e.g., sometimes gave weird results on single layer of beads
transform_model_type = "affine"

# error threshold in RANSAC -> lower means more stringent filtering,
# but may lead to no transformation being estimated at all
residual_threshold = 0.1

# how many rounds of RANSAC to do (max)
# more may help when you have very few inliers
ransac_max_trials = 5_000

# name of the length unit; used in printed summaries and stored as 'size_unit' in the saved JSON
pixel_unit = 'micron'

## 1. Load & combine spot detections

First, we load and concatenate all detection tables:

In [None]:
from pathlib import Path
import pandas as pd

# Collect all detection tables below the spot detection subfolder (sorted for reproducible order).
detection_dir = Path(in_path) / detection_subdirectory
in_files = sorted(detection_dir.glob(file_pattern))

# Fail early with a descriptive message: pd.concat on an empty list only reports
# "No objects to concatenate", while the actual cause is usually a wrong
# path / file pattern or a network share that is not mounted.
if not in_files:
    raise FileNotFoundError(f"no files matching '{file_pattern}' found in '{detection_dir}'")

# Concatenate all tables into one with a fresh 0..n-1 index.
df = pd.concat([pd.read_csv(in_file) for in_file in in_files], ignore_index=True)

df

## 2. Match detections between channels

Next, we match detections between channels for all pairs of channels. We use linear assignment, but also discard matches above a maximum distance (see parameters).

This works fine for applications like chromatic shift correction, but it assumes that shifts are small enough for detections to be matched purely by distance.

**TODO:** also add descriptor-based matching for larger shifts.

In [None]:
from itertools import combinations
from collections import defaultdict

import numpy as np
from scipy.optimize import linear_sum_assignment

from calmutils.localization.metrics import get_coord_distance_matrix


# Match detections between every pair of channels, separately for each image file.
pair_tables = []

for file_path, dfi in df.groupby(filename_column):

    # (channel, sub-table) tuples of this image; every unordered pair of them is matched
    channel_groups = list(dfi.groupby(channel_column))

    for (ch1, dfi_ch1), (ch2, dfi_ch2) in combinations(channel_groups, 2):

        first_channel, second_channel = sorted((ch1, ch2))

        coords_ch1 = dfi_ch1[coordinate_columns].values
        coords_ch2 = dfi_ch2[coordinate_columns].values

        # pairwise distances; anything beyond the maximum distance gets a very
        # large cost so that the assignment avoids it whenever possible
        dist = get_coord_distance_matrix(coords_ch1, coords_ch2)
        dist[dist > matching_max_dist] = matching_max_dist * 9000

        # optimal one-to-one assignment: indices into coords_ch1 and coords_ch2
        idx_ch1, idx_ch2 = linear_sum_assignment(dist)

        # drop the pairs that could only be assigned at the penalty cost
        keep = dist[idx_ch1, idx_ch2] < matching_max_dist
        coords_ch1_matched = coords_ch1[idx_ch1[keep]]
        coords_ch2_matched = coords_ch2[idx_ch2[keep]]

        # one column per coordinate and channel: first all *_ch1, then all *_ch2
        columns = {}
        for suffix, coords in (("ch1", coords_ch1_matched), ("ch2", coords_ch2_matched)):
            for col, values in zip(coordinate_columns, coords.T):
                columns[f"{col}_{suffix}"] = values

        matched_dfi = pd.DataFrame(columns)
        matched_dfi["channel1"] = first_channel
        matched_dfi["channel2"] = second_channel
        matched_dfi["image_file"] = file_path

        pair_tables.append(matched_dfi)


# the per-pair row index is kept as "spot_idx"
matched_df = pd.concat(pair_tables).reset_index(names="spot_idx")
matched_df

## 3. Add FOV info to table

For transform estimation, we also need to know the FOV size of the corresponding images and the direction of the z-axis (bottom-to-top or top-to-bottom).

### ALTERNATIVE 1: Read from nd2 files


In [None]:
import nd2
from nd2.structures import ZStackLoop
from calmutils.misc.file_utils import get_common_subpath

# Collect FOV size, voxel size and z-direction for every image file, one record per file.
fov_records = []

for file_path, _group in df.groupby(filename_column):

    # translate the path stored in the table into a path below the local mount (in_path)
    # NOTE: we assume the paths share at least some common subpath
    _, (prefix_remote, prefix_local), _ = get_common_subpath(file_path, in_path)
    file_path_local = file_path.replace(prefix_remote, prefix_local, 1)

    with nd2.ND2File(file_path_local) as reader:
        # z-direction flag of the first z-stack loop in the experiment, None if there is none
        bottom_to_top = None
        for loop in reader.experiment:
            if isinstance(loop, ZStackLoop):
                bottom_to_top = loop.parameters.bottomToTop
                break

        # image extent in pixels (z, y, x)
        fov_pixel = np.array([reader.sizes[dim] for dim in 'ZYX'], dtype=float)
        # voxel size, reversed to line up with the ZYX order above
        # (presumably voxel_size() returns x, y, z -- confirm against nd2 docs)
        pixel_sizes = np.array(reader.voxel_size()[::-1], dtype=float)

    # physical extent = number of pixels * voxel size
    fov_micron = fov_pixel * pixel_sizes

    record = {'image_file': file_path, 'bottom_to_top': bottom_to_top}
    for dim, fov_dim, pixel_size_dim in zip('zyx', fov_micron, pixel_sizes):
        record[f'fov_micron_{dim}'] = fov_dim
        record[f'pixel_size_micron_{dim}'] = pixel_size_dim

    fov_records.append(record)

fov_info_df = pd.DataFrame(fov_records)

# attach the per-file FOV info to every matched spot
matched_df = matched_df.merge(fov_info_df, on='image_file')

### ALTERNATIVE 2: Manually set FOV size, z-direction

In [None]:
# Manually specified FOV extent and pixel size, both in (z, y, x) order,
# plus the z-direction; the same values are written to every matched spot.
fov_manual = [10, 133.12, 133.12]
pixel_size_manual = [0.3, 0.13, 0.13]
bottom_to_top = True

for dim, fov_dim, pixel_size_dim in zip('zyx', fov_manual, pixel_size_manual):
    matched_df[f'fov_micron_{dim}'] = fov_dim
    matched_df[f'pixel_size_micron_{dim}'] = pixel_size_dim

matched_df['bottom_to_top'] = bottom_to_top

## 4. Estimate Transformation

In [None]:
from skimage.transform import AffineTransform, SimilarityTransform
from skimage.measure import ransac

def affine_transform_nd(dimensionality):
    """Return a zero-argument factory for AffineTransform objects.

    The factory creates an AffineTransform with the given dimensionality,
    which would default to 2 if the class were used as constructor directly.
    """
    def make_transform():
        return AffineTransform(dimensionality=dimensionality)

    return make_transform

# get constructor for selected transform
transform_type = {
    "affine": affine_transform_nd(3),
    "similarity": SimilarityTransform
}[transform_model_type]

# matched spots with inlier flag per channel pair (for transformation field visualization)
inlier_dfs = {}

# 4x4 homogeneous matrices, keyed by (source channel, target channel)
transforms = {}

# z is the first coordinate axis: z_flip_arr negates it, z_select_arr selects it for the FOV offset
z_flip_arr = np.array([-1, 1, 1])
z_select_arr = np.array([1, 0, 0])

# number of point pairs drawn per RANSAC trial
ransac_min_samples = 4

for (ch1, ch2), dfi in matched_df.groupby(['channel1', 'channel2']):

    # RANSAC raises if there are fewer points than it has to draw per trial
    if len(dfi) < ransac_min_samples:
        print(f'RANSAC on {ch1}->{ch2} skipped: only {len(dfi)} matched spots')
        continue

    # get coords & FOV info
    matched_coords_ch1 = dfi[[f"{d}_micron_ch1" for d in 'zyx']].values
    matched_coords_ch2 = dfi[[f"{d}_micron_ch2" for d in 'zyx']].values
    fov = dfi[[f"fov_micron_{d}" for d in 'zyx']].values
    bottom_to_top = dfi["bottom_to_top"].values.reshape((-1, 1))

    # flip coordinates if not bottom-to-top then move back up by fov (only in z!)
    # Done out-of-place: the arrays returned by .values may be read-only views
    # (pandas Copy-on-Write) or have an integer dtype, both of which break "*=" / "+=".
    flip = np.where(bottom_to_top, np.ones(3), z_flip_arr)
    offset = np.where(bottom_to_top, np.zeros(3), z_select_arr) * fov
    matched_coords_ch1 = matched_coords_ch1 * flip + offset
    matched_coords_ch2 = matched_coords_ch2 * flip + offset

    # TODO: correct for unequal FOVs, one possible solution is to center:
    # matched_coords_ch1 -= fov / 2
    # matched_coords_ch2 -= fov / 2

    # do ransac
    transform, inliers = ransac((matched_coords_ch1, matched_coords_ch2),
                                transform_type, ransac_min_samples,
                                residual_threshold=residual_threshold, max_trials=ransac_max_trials)

    # ransac returns (None, None) if no model with inliers was found,
    # e.g. when residual_threshold is too strict
    if transform is None or inliers is None:
        print(f'RANSAC on {ch1}->{ch2} failed: no transformation could be estimated')
        continue

    print(f'RANSAC on {ch1}->{ch2} inliers: {inliers.sum()}/{len(inliers)}')

    # save copy of matched coordinates with inlier info (for transformation field visualization)
    inlier_df = dfi.copy()
    inlier_df["inlier"] = inliers
    inlier_dfs[(ch1, ch2)] = inlier_df

    # print some distance details (inliers only):
    # mean Euclidean distance and mean per-axis shift, before and after the transform
    delta_before = matched_coords_ch1[inliers] - matched_coords_ch2[inliers]
    delta_after = transform(matched_coords_ch1[inliers]) - matched_coords_ch2[inliers]

    dist_before_norm = np.linalg.norm(delta_before, axis=1).mean()
    dist_before = delta_before.mean(axis=0)
    dist_after_norm = np.linalg.norm(delta_after, axis=1).mean()
    dist_after = delta_after.mean(axis=0)

    print(f'mean distance before transform: {dist_before_norm:.3f} {pixel_unit}, after: {dist_after_norm:.3f} {pixel_unit}')
    print(f'mean shift before transform: {dist_before} {pixel_unit}, after: {dist_after} {pixel_unit}')

    # put matrix of estimated transform plus inverse into results
    transforms[(ch1, ch2)] = transform.params
    transforms[(ch2, ch1)] = np.linalg.inv(transform.params)

## 5. Save Results as JSON

In [None]:
import json
from pathlib import Path

out_file = Path(in_path) / 'channel_registration_multifile-3c.json'


def _to_builtin(value):
    """Return a plain Python scalar for NumPy scalars, any other value unchanged."""
    return value.item() if hasattr(value, 'item') else value


# Channel labels may be NumPy scalars (e.g. integer channel ids), which the json
# module cannot serialize, so everything is converted to built-in Python types
# before dumping. Series.tolist() yields Python scalars in order of first appearance.
# TODO: also save FOV?
output = {
    'channels': df[channel_column].drop_duplicates().tolist(),
    # 'pixel_size' : list(pixel_size),
    'size_unit': pixel_unit,
    # coordinates were normalized to bottom-to-top before transform estimation
    'z_direction': "bottom_to_top",
    # 'field_of_view' : list(np.array(next(iter(images.values())).shape) * pixel_size),
    'source_file': df[filename_column].drop_duplicates().tolist(),
    # one entry per (source, target) channel pair, matrix flattened row by row
    'transforms': [
        {'channels': [_to_builtin(ch) for ch in k], 'parameters': v.ravel().tolist()}
        for k, v in transforms.items()
    ]
}

with open(out_file, 'w') as fd:
    json.dump(output, fd, indent=1)

## (Optional) Cross-Validation

Here, we go over the matched data file-by-file and perform leave-one-out CV: we estimate transforms from all but one file and use the estimation to transform the coordinates of that file. Thus, we can estimate how well the transformations generalize between files.

In [None]:
df_corrected = []

# z is the first coordinate axis: z_flip_arr negates it, z_select_arr selects it for the FOV offset.
# Defined in this cell as well so it does not rely on variables left over from the estimation cell.
z_flip_arr = np.array([-1, 1, 1])
z_select_arr = np.array([1, 0, 0])

coord_columns_ch1 = [f"{d}_micron_ch1" for d in 'zyx']
coord_columns_ch2 = [f"{d}_micron_ch2" for d in 'zyx']
fov_columns = [f"fov_micron_{d}" for d in 'zyx']

# number of point pairs drawn per RANSAC trial
cv_min_samples = 4

for (ch1, ch2), dfi in matched_df.groupby(['channel1', 'channel2']):

    for file in dfi.image_file.unique():

        # leave-one-out split: all other files vs. current file.
        # The split has to be done on dfi (current channel pair only); splitting
        # the full table would mix matches of all channel pairs into both sets.
        is_self = dfi.image_file == file
        df_others = dfi[~is_self]
        df_self = dfi[is_self]

        # RANSAC raises if there are fewer points than it has to draw per trial
        if len(df_others) < cv_min_samples:
            print(f'CV for {ch1}->{ch2}, {file}: skipped, only {len(df_others)} matched spots in other files')
            continue

        # get "training set" (others) coords
        fov = df_others[fov_columns].values
        bottom_to_top = df_others["bottom_to_top"].values.reshape((-1, 1))

        # flip coordinates if not bottom-to-top then move back up by fov (only in z!)
        # Done out-of-place: arrays returned by .values may be read-only views.
        flip = np.where(bottom_to_top, np.ones(3), z_flip_arr)
        offset = np.where(bottom_to_top, np.zeros(3), z_select_arr) * fov
        matched_coords_ch1 = df_others[coord_columns_ch1].values * flip + offset
        matched_coords_ch2 = df_others[coord_columns_ch2].values * flip + offset

        # estimate transform
        transform, inliers = ransac((matched_coords_ch1, matched_coords_ch2),
                                    transform_type, cv_min_samples,
                                    residual_threshold=residual_threshold, max_trials=ransac_max_trials)

        # ransac returns (None, None) if no model with inliers was found
        if transform is None:
            print(f'CV for {ch1}->{ch2}, {file}: skipped, no transformation could be estimated')
            continue

        # get "test set" (self) coords, normalized the same way
        fov = df_self[fov_columns].values
        bottom_to_top = df_self["bottom_to_top"].values.reshape((-1, 1))
        flip = np.where(bottom_to_top, np.ones(3), z_flip_arr)
        offset = np.where(bottom_to_top, np.zeros(3), z_select_arr) * fov
        self_coords_ch1 = df_self[coord_columns_ch1].values * flip + offset

        # apply transform to ch1, then undo FOV corrections (reverse order: shift back, flip back)
        transformed_coords = (transform(self_coords_ch1) - offset) * flip

        # rows already carry the channel1/channel2 labels of the current pair
        df_self = df_self.copy()
        for i, d in enumerate('zyx'):
            df_self[f"{d}_micron_ch1_corr"] = transformed_coords.T[i]

        df_corrected.append(df_self)

df_corrected = pd.concat(df_corrected)
df_corrected



## (Optional Alternative) Apply transforms from file

In [None]:
import json

# saved_transform_path = '/Volumes/agl_data/NanoFISH/Gabi/GS651_Nanog_2-color_2nM/channel_registration_multifile.json'
saved_transform_path = "/Volumes/agl_data/NanoFISH/Gabi/GS666_tetraspeck_on_cells_1-50/channel_registration_multifile.json"

with open(saved_transform_path) as fd:
    transform_info = json.load(fd)

df_corrected = []

# z is the first coordinate axis: z_flip_arr negates it, z_select_arr selects it for the FOV offset
z_flip_arr = np.array([-1, 1, 1])
z_select_arr = np.array([1, 0, 0])

for (ch1, ch2), dfi in matched_df.groupby(['channel1', 'channel2']):

    # get coords & FOV info (only ch1 is transformed here)
    matched_coords_ch1 = dfi[[f"{d}_micron_ch1" for d in 'zyx']].values
    fov = dfi[[f"fov_micron_{d}" for d in 'zyx']].values
    bottom_to_top = dfi["bottom_to_top"].values.reshape((-1, 1))

    # flip coordinates if not bottom-to-top then move back up by fov (only in z!)
    # Done out-of-place: arrays returned by .values may be read-only views.
    flip = np.where(bottom_to_top, np.ones(3), z_flip_arr)
    offset = np.where(bottom_to_top, np.zeros(3), z_select_arr) * fov
    matched_coords_ch1 = matched_coords_ch1 * flip + offset

    # look up the saved ch1 -> ch2 matrix; fail with a clear message if the file
    # does not contain this channel pair (a bare next() would raise StopIteration)
    parameters = next(
        (tr["parameters"] for tr in transform_info["transforms"] if tr["channels"] == [ch1, ch2]),
        None,
    )
    if parameters is None:
        raise KeyError(f"no transform for channels {(ch1, ch2)} found in '{saved_transform_path}'")

    # the saved parameters are a flattened 4x4 matrix
    # NOTE(review): SimilarityTransform is used for every saved matrix, including
    # ones estimated with the "affine" model -- confirm it applies them unchanged
    transform = SimilarityTransform(matrix=np.array(parameters).reshape((4, 4)))

    # apply transform to ch1, then undo FOV corrections (reverse order: shift back, flip back)
    transformed_coords = (transform(matched_coords_ch1) - offset) * flip

    dfi = dfi.copy()
    for i, d in enumerate('zyx'):
        dfi[f"{d}_micron_ch1_corr"] = transformed_coords.T[i]

    dfi['channel1'] = ch1
    dfi['channel2'] = ch2
    df_corrected.append(dfi)

df_corrected = pd.concat(df_corrected)
df_corrected

## Mean shifts / visualization of shift components

In [None]:
# Raw ch1, corrected ch1 and ch2 coordinates (z, y, x order) of all spots in the corrected table.
matched_coords_ch1, matched_coords_ch1_corr, matched_coords_ch2 = (
    df_corrected[[f"{d}_micron_{suffix}" for d in 'zyx']].values
    for suffix in ("ch1", "ch1_corr", "ch2")
)

# Per-spot displacement of ch1 relative to ch2, before and after correction.
shift_before = matched_coords_ch1 - matched_coords_ch2
shift_after = matched_coords_ch1_corr - matched_coords_ch2

# Mean Euclidean distance and mean per-axis shift.
dist_before_norm = np.linalg.norm(shift_before, axis=1).mean()
dist_before = shift_before.mean(axis=0)

dist_after_norm = np.linalg.norm(shift_after, axis=1).mean()
dist_after = shift_after.mean(axis=0)

print(f'mean distance before transform: {dist_before_norm:.3f} {pixel_unit}, after: {dist_after_norm:.3f} {pixel_unit}')
print(f'mean shift before transform: {dist_before} {pixel_unit}, after: {dist_after} {pixel_unit}')

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Per-axis residuals (ch2 - ch1), uncorrected and corrected; one column per axis and state.
residuals = {
    'uncorr': matched_coords_ch2 - matched_coords_ch1,
    'corr': matched_coords_ch2 - matched_coords_ch1_corr,
}

plot_columns = {}
for state, delta in residuals.items():
    for axis, values in zip('zyx', delta.T):
        plot_columns[f'd{axis}_{state}'] = values

# long format, then split e.g. "dz_uncorr" into axis ("dz") and correction state ("uncorr")
df_plot = pd.DataFrame(plot_columns).melt()
df_plot[["variable", "corrected"]] = df_plot["variable"].str.split("_", expand=True)

# one histogram panel per axis, corrected vs. uncorrected overlaid
g = sns.FacetGrid(df_plot, col="variable", hue='corrected', sharex=False, xlim=(-0.3, 0.3))
g.map(sns.histplot, "value",  alpha=.1, stat='probability', element='step')
g.add_legend()

# g.savefig("/home/david/Documents/PromoterEnhancer_Revision/shift_calibration_hist.pdf")

## Visualize shift field

In [None]:
from scipy.interpolate import LinearNDInterpolator


for (ch1, ch2), dfi in inlier_dfs.items():

    # lateral (y, x) pixel size and FOV, taken from the first row, i.e. this
    # assumes all images of a channel pair share the same geometry
    pixel_size = dfi[["pixel_size_micron_y", "pixel_size_micron_x"]].values[0]
    fov = dfi[["fov_micron_y", "fov_micron_x"]].values[0]

    # image shape in pixels (rows = y, columns = x)
    shape = tuple(np.round(fov/pixel_size).astype(int))

    # RANSAC inliers only
    matched_coords_ch1 = dfi[[f"{d}_micron_ch1" for d in 'zyx']][dfi.inlier].values
    matched_coords_ch2 = dfi[[f"{d}_micron_ch2" for d in 'zyx']][dfi.inlier].values

    # (z, y, x) shift of every inlier, interpolated over the (y, x) position in ch1.
    # A single vector-valued interpolator triangulates the points only once;
    # outside the convex hull of the spots the shift is set to 0.
    shifts = matched_coords_ch2 - matched_coords_ch1
    interp = LinearNDInterpolator(matched_coords_ch1[:, 1:], shifts, fill_value=0)

    # physical (y, x) coordinate of every pixel
    sample_coords = np.stack(np.mgrid[:shape[0], :shape[1]], -1) * pixel_size

    shift_int = interp(sample_coords.reshape((-1, 2)))
    shift_z_int = shift_int[:, 0].reshape(shape)
    shift_y_int = shift_int[:, 1].reshape(shape)
    shift_x_int = shift_int[:, 2].reshape(shape)

    # xy shift angle; np.arctan2 instead of np.atan2, which only exists in NumPy >= 2.0
    shift_angle = np.arctan2(shift_x_int, shift_y_int)

    fig, axs = plt.subplots(ncols=3, figsize=(12, 4))

    plt0 = axs[0].imshow(np.linalg.norm([shift_y_int, shift_x_int], axis=0), clim=(0, 0.25), cmap='magma')
    axs[0].axis('off')
    axs[0].set_title("XY shift magnitude across FOV")
    plt.colorbar(plt0, shrink=0.5, location='bottom')

    plt1 = axs[1].imshow(shift_angle, cmap='jet')
    axs[1].axis('off')
    axs[1].set_title("XY shift angle across FOV")
    plt.colorbar(plt1, shrink=0.5, location='bottom')

    plt2 = axs[2].imshow(shift_z_int, clim=(0, 0.5), cmap='magma')
    axs[2].axis('off')
    axs[2].set_title("Z shift across FOV")
    plt.colorbar(plt2, shrink=0.5, location='bottom')

    fig.suptitle((ch1, ch2))
    fig.tight_layout()

    # fig.savefig(f"/home/david/Desktop/shift_visualization_{ch1}_{ch2}.pdf")

# dfi