<div style="display: inline; justify-content: space-between;">
    <img src="assets/jupyter_logo.png" width="60px"/>
    <span>&nbsp;</span>
    <img src="assets/cruk_logo.jpg" width="260px" style="padding: 4px"/>
    <span>&nbsp;</span>
    <img src="assets/ioa_logo.png" width="80px"/>
</div>

# Image Registration

In this tutorial we use two nuclear images from MerFISH, observed in two consecutive fields of view, to find the relative offset between them and create a mosaic.

## Module Imports

* [Zarr](https://zarr.readthedocs.io/en/stable/) -- chunked, compressed, N-dimensional arrays
* [Matplotlib](https://matplotlib.org), [Astropy](https://www.astropy.org) -- plotting
* [Scipy](https://www.scipy.org) -- image processing

In [1]:
# Record the interpreter, OS and package versions this notebook was run with.
%info_versions -p zarr matplotlib astropy scipy numpy

Python                        : 3.7.3 64bit GCC 7.3.0
OS                            : Linux 3.10.0 862.14.4.el7.x86_64 x86_64 with debian buster sid

astropy                       : 3.2.1
matplotlib                    : 3.1.1
numpy                         : 1.17.1
scipy                         : 1.3.1
zarr                          : 2.3.2


In [None]:
import zarr

import matplotlib.pyplot as plt
import numpy as np
from astropy.visualization import ZScaleInterval
from scipy.ndimage import fourier_shift, median_filter, shift, label, fourier_gaussian, find_objects, rotate

## Read images

In [None]:
# Open the test dataset in read-only mode; `z` is the root of the hierarchy.
z = zarr.open(store='/data/meds1_b/imaxt/merfish/test_merFISH_data', mode='r')

The first level contains the fields of view:

In [None]:
# Materialise the groups found directly under the root.
list(z.groups())

Then, for each field of view, the z planes:

In [None]:
# Materialise the groups found under the first field of view.
list(z['fov=0'].groups())

Each plane contains 8 cycle images:

In [None]:
# Materialise the groups found under the first z plane of the first field of view.
list(z['fov=0/z=0'].groups())

Each cycle contains 4 images: nuclei, microbeads and two bit observations:

In [None]:
# Materialise the groups found under the first cycle of that plane.
list(z['fov=0/z=0/cycle=0'].groups())

Metadata can be accessed from ``z.attrs``:

In [None]:
# Value stored under the 'bitnames' key of the root-level attributes.
z.attrs['bitnames']

In [None]:
# Value stored under the 'orig' key of the root-level attributes.
z.attrs['orig']

In [None]:
import pandas as pd

# Pull the codebook out of the root attributes and render it as a table.
codebook = z.attrs['codebook']
codebook_table = pd.DataFrame.from_dict(codebook)
codebook_table

In [None]:
# Raw nuclei images (first z plane, first cycle) of the first two fields of view.
NUCLEI_PATHS = ('fov=0/z=0/cycle=0/nuclei/raw', 'fov=1/z=0/cycle=0/nuclei/raw')
im0, im1 = (z[path] for path in NUCLEI_PATHS)

In [None]:
# Show both tiles side by side using one intensity range, taken from the
# ZScale limits of the first tile so the panels are directly comparable.
zs = ZScaleInterval()
vmin, vmax = zs.get_limits(im0)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18,16))
for panel, tile in zip(ax, (im0, im1)):
    panel.imshow(tile, vmin=vmin, vmax=vmax)

## Compute offsets

To compute the offsets, one method is to take the first image as the reference and then move the second image...

This, known as cross-correlation, is an expensive operation in the image space, but using the Fourier transform, it is a simple multiplication in the frequency domain, i.e.:

\begin{equation}
\mathcal{F}(f \star g) = \mathcal{F}(f)^* \, \mathcal{F}(g)
\end{equation}

Tiles are transformed into the frequency domain by the discrete Fourier transform, giving $\mathcal{F}_0$ and $\mathcal{F}_1$:

In [None]:
# Promote both tiles to complex arrays for the FFT.  np.asarray copies only
# when it has to, whereas np.array(..., copy=False) raises on NumPy >= 2
# whenever the conversion cannot be done without a copy.
src_image = np.asarray(im0, dtype=np.complex128)
target_image = np.asarray(im1, dtype=np.complex128)

# Discrete Fourier transforms of the reference and the target tile.
F_0 = np.fft.fftn(src_image)
F_1 = np.fft.fftn(target_image)

# Versions with the zero-frequency component moved to the centre, for display.
F_0_s = np.fft.fftshift(F_0)
F_1_s = np.fft.fftshift(F_1)

In [None]:
# Amplitude spectra of both tiles, zero frequency centred, drawn on a common
# intensity range derived from the first spectrum.
F_0_s = np.fft.fftshift(F_0)
F_1_s = np.fft.fftshift(F_1)
zs = ZScaleInterval()
vmin, vmax = zs.get_limits(np.abs(F_0_s))
freq_extent = (-np.pi, np.pi, -np.pi, np.pi)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18,16))
for panel, spectrum in zip(ax, (F_0_s, F_1_s)):
    panel.imshow(np.abs(spectrum), vmin=vmin, vmax=3*vmax, extent=freq_extent)

In [None]:
# Multiply each spectrum by the Fourier transform of a Gaussian (sigma = 5).
# The filter is applied to a fresh transform of each tile rather than to the
# current F_0 / F_1, so re-running this cell does not smooth the spectra a
# second time.
F_0 = fourier_gaussian(np.fft.fftn(src_image), 5)
F_1 = fourier_gaussian(np.fft.fftn(target_image), 5)

In [None]:
# Amplitude spectra after the Gaussian filter.  vmin / vmax are the limits
# computed for the unfiltered spectra, so both figures share one scale.
F_0_s = np.fft.fftshift(F_0)
F_1_s = np.fft.fftshift(F_1)
freq_extent = (-np.pi, np.pi, -np.pi, np.pi)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18,16))
for panel, spectrum in zip(ax, (F_0_s, F_1_s)):
    panel.imshow(np.abs(spectrum), vmin=vmin, vmax=3*vmax, extent=freq_extent)

Cross correlation $\Phi_{10} = \mathcal{F}_1 \times \mathcal{F}_0^*$ and auto-correlation terms $\Phi_{00} = \mathcal{F}_0 \times \mathcal{F}_0^*$ and $\Phi_{11} = \mathcal{F}_1 \times \mathcal{F}_1^*$

In [None]:
# Complex conjugates, computed once and shared by the products below.
F_0_conj = np.conj(F_0)
F_1_conj = np.conj(F_1)

# Cross terms, in both orderings.
phi_10 = F_1 * F_0_conj
phi_01 = F_0 * F_1_conj

# Auto terms of each tile.
phi_00 = F_0 * F_0_conj
phi_11 = F_1 * F_1_conj

The auto-correlation terms are used to normalise the cross-correlation term, giving an enhanced cross-correlation term:

\begin{equation}
P = \frac{\Phi_{10}}{\sqrt{\Phi_{00} * \Phi_{11}} + \epsilon}
\end{equation}

In [None]:
# Normalise the cross term by the square root of the product of the auto
# terms; the small constant keeps the denominator away from zero.
# NOTE: the numerator is phi_01 (F_0 times conj(F_1)), i.e. the complex
# conjugate of the Phi_10 written in the formula of the preceding text.
normalisation = np.sqrt(phi_00 * phi_11) + 1e-10
P = fourier_gaussian(phi_01 / normalisation, 5)

# Back to image space.
enhanced_correlation = np.fft.ifftn(P)

In [None]:
# Real part of the correlation map: the whole map on the left and a
# 100 x 100 pixel cut-out (rows 1800-1900, columns 0-100) on the right.
corr = enhanced_correlation.real
zs = ZScaleInterval()
vmin, vmax = zs.get_limits(corr)
display_range = dict(vmin=vmin, vmax=3*vmax)
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(18,16))
full_panel, zoom_panel = ax
full_panel.imshow(corr, **display_range)
zoom_panel.imshow(corr[1800:1900,0:100], extent=(0,100,1900,1800), **display_range);

In [None]:
from imaxt_image.registration import find_shift

# Options handed to find_shift.
# NOTE(review): `overlap` is presumably the expected fractional overlap range
# and `border_width` a border in pixels -- confirm against the find_shift docs.
shift_options = dict(overlap=(0.08, 0.12), border_width=20)

res = find_shift(im0, im1, **shift_options)
res

In [None]:
# Integer (row, column) position taken from the find_shift result.
maxima = (int(res['y']), int(res['x']))

# Average a 10-row by 20-column window around that position over its rows to
# obtain a profile along the columns.  The correlation map comes from an
# inverse FFT and is therefore periodic, so indices are wrapped around the
# edges; a plain slice would come back empty (or truncated) whenever the peak
# lies closer to an edge than the window half-width.
rows = np.arange(maxima[0] - 5, maxima[0] + 5)
cols = np.arange(maxima[1] - 10, maxima[1] + 10)
window = corr.take(rows, axis=0, mode='wrap').take(cols, axis=1, mode='wrap')
plt.plot(np.mean(window, axis=0));

In [None]:
from skimage.transform import warp, AffineTransform

# Shape of the canvas both tiles are resampled onto.
mosaic_shape = (4000,2100)

# Array of ones with the shape of a tile; warped alongside each image it marks
# which canvas pixels that tile covers.
c0 = np.ones_like(im0)

# Reference tile: identity transform, i.e. placed at the canvas origin.
identity = AffineTransform()
image0 = warp(im0, identity, output_shape=mosaic_shape)
conf0 = warp(c0, identity, output_shape=mosaic_shape) + 1e-10

# Second tile: translated by the measured shift.  The translation is given as
# (x, y), hence column offset first.
tform = AffineTransform(translation=(maxima[1], maxima[0]))
image = warp(im1, tform.inverse, output_shape=mosaic_shape)
conf = warp(c0, tform.inverse, output_shape=mosaic_shape) + 1e-10

# Sum of the tiles divided by the summed coverage; the 1e-10 added to the
# coverage maps keeps the division finite where no tile contributes.
mosaic = (image + image0) / (conf0 + conf)
plt.figure(figsize=(12,12))
plt.imshow(mosaic);

# Parallel Computation using Dask

## Starting the Dask Cluster

In [1]:
from distributed import Client
from dask import delayed, compute, visualize
import dask.array as da
from dask_kubernetes import KubeCluster

In [2]:
# Number of workers requested when the cluster is created.
N_WORKERS = 3
cluster = KubeCluster(n_workers=N_WORKERS)

In [None]:
# Display the cluster object.
cluster

In [None]:
# Create a client connected to the cluster created above.
client = Client(cluster)

In [None]:
#futures = []
#for j in range(1):
#    for i in range(12-1):
#        im0 = z[f'fov={i}/z=0/cycle=0/nuclei/raw']
#        im1 = z[f'fov={i+1}/z=0/cycle=0/nuclei/raw']
#        fut = client.submit(find_shift, im0, im1, overlap=(0.08, 0.12), border_width=20)
#        futures.append([i, i+1, fut])

In [None]:
#offsets = [(i,j,fut.result()) for i,j,fut in futures]
#offsets

In [None]:
from skimage.transform import warp, AffineTransform

@delayed(nout=2)
def abs_offset(xoff, yoff, offsets):
    """Add a relative shift to a running absolute offset.

    Parameters
    ----------
    xoff, yoff
        Absolute offset accumulated so far.
    offsets
        Mapping with ``'x'`` and ``'y'`` entries holding the shift to add.

    Returns
    -------
    tuple
        The updated ``(xoff, yoff)`` pair.
    """
    return xoff + offsets['x'], yoff + offsets['y']

@delayed(nout=2)
def get_warp(im, x, y, output_shape=(8000, 6000)):
    """Resample a tile onto the mosaic canvas at a given offset.

    Parameters
    ----------
    im
        Tile image to place on the canvas.
    x, y
        Translation of the tile on the canvas, passed to ``AffineTransform``
        as ``translation=(x, y)``.
    output_shape
        Shape of the canvas to generate.  Defaults to ``(8000, 6000)``, the
        value that was previously hard-coded.

    Returns
    -------
    tuple
        ``(image, conf)``: the warped tile, and an array of ones shaped like
        the tile warped with the same transform.
    """
    tform = AffineTransform(translation=(x, y))
    image = warp(im, tform.inverse, output_shape=output_shape)
    # Warping ones with the same transform marks the canvas pixels covered by
    # this tile.
    conf = warp(np.ones_like(im), tform.inverse, output_shape=output_shape)
    return image, conf
    
# Number of fields of view to chain together.
N_FOV = 12
# Canvas shape produced by get_warp and the chunking used for the Dask arrays.
CANVAS_SHAPE = (8000, 6000)
CHUNKS = (2000,2000)

images = []
conf = []
for i in range(N_FOV):
    im1 = z[f'fov={i}/z=0/cycle=0/nuclei/raw']
    if i == 0:
        # The first tile defines the starting offset on the canvas.
        xoff, yoff = 0, 200
    else:
        # Shift relative to the previous tile, added to the running offset.
        im0 = z[f'fov={i-1}/z=0/cycle=0/nuclei/raw']
        offsets = delayed(find_shift)(im0, im1, overlap=(0.08, 0.12), border_width=20)
        xoff, yoff = abs_offset(xoff, yoff, offsets)
    img, cf = get_warp(im1, xoff, yoff)
    # Wrap the delayed results as Dask arrays; nothing is computed yet.
    images.append(da.from_delayed(img, CANVAS_SHAPE, dtype='float').rechunk(CHUNKS))
    conf.append(da.from_delayed(cf, CANVAS_SHAPE, dtype='float').rechunk(CHUNKS))

In [None]:
# Display the Dask array built for the first tile.
images[0]

In [None]:
# Stack the per-tile canvases, and likewise their coverage maps, along a new
# leading axis.
stack, cstack = (da.stack(layers) for layers in (images, conf))

In [None]:
# Display the stacked Dask array.
stack

In [None]:
# Canvas holding the tile at index 8 of the stack.
fig, ax = plt.subplots(figsize=(10,15))
ax.imshow(stack[8]);

In [None]:
# Plain sum of all tile canvases (overlaps are counted more than once).
fig, ax = plt.subplots(figsize=(10,15))
ax.imshow(stack.sum(axis=0));

In [None]:
# Mosaic: summed tile canvases divided by the summed coverage maps.  The small
# constant in the denominator matches the two-tile mosaic earlier in the
# notebook and avoids 0/0 (NaN values and runtime warnings) on canvas pixels
# that no tile covers.
fig = plt.figure(figsize=(10,15))
plt.imshow(stack.sum(axis=0) / (cstack.sum(axis=0) + 1e-10));

In [None]:
# Release resources: close the client first, then the cluster it was attached to.
client.close()
cluster.close()