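# MERFISH-to-atlas registration

Register a rasterized MERFISH section to a coronal slice of the Allen Nissl atlas (`ara_nissl_25.nrrd`), initializing LDDMM from hand-picked corresponding landmarks.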

In [1]:
import numpy as np
%matplotlib notebook
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib import cm
from matplotlib.lines import Line2D
import os
import glob
import torch

from scipy.stats import rankdata
import nrrd
import time

import imp
import tools
imp.reload(tools)

  del sys.path[0]


<module 'tools' from '/ifshome/oamiuwu/STalign/tools.py'>

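# Load MERFISH cell positions and rasterize them (image `I`, the moving image; the `tools` transforms call it the "atlas")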

In [2]:
# Directory holding the per-section MERFISH cell-metadata tables (*metadata*.csv.gz).
MERFISH_DIR = '/ifshome/dtward/data/merfish/jean_fan_2021/OneDrive_1_8-5-2021'
# Index, after sorting by filename, of the section to register.
MERFISH_FILE_INDEX = 5

files = sorted(glob.glob(os.path.join(MERFISH_DIR, '*metadata*.csv.gz')))
# Fail with a readable message, not a bare IndexError, when the data is missing.
assert len(files) > MERFISH_FILE_INDEX, (
    f'found {len(files)} metadata files in {MERFISH_DIR}; '
    f'need more than {MERFISH_FILE_INDEX}'
)
fname = files[MERFISH_FILE_INDEX]

df = pd.read_csv(fname)
# Cell centroid coordinates, in whatever units the CSV stores them.
xM = np.array(df['center_x'])
yM = np.array(df['center_y'])

# Quick look at the raw cell positions.
fig, ax = plt.subplots()
ax.scatter(xM, yM, s=1, alpha=0.25)
ax.set_aspect('equal')
ax.set_title(f'MERFISH cell centers: {os.path.basename(fname)}')

# Rasterize the point cloud. M is used below as a (channels, rows, cols) image.
# tools.rasterize also returns its own figure, which replaces fig/ax from here on.
X, Y, M, fig = tools.rasterize(xM, yM)
ax = fig.axes[0]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

0 of 85958
10000 of 85958
20000 of 85958
30000 of 85958
40000 of 85958
50000 of 85958
60000 of 85958
70000 of 85958
80000 of 85958
85957 of 85958


# Load target image: a coronal slice of the Allen Nissl atlas (image `J`)

In [10]:
# Allen Reference Atlas Nissl volume. The "_25" suffix and the 25.-unit grid
# spacing used below suggest 25 um voxels (TODO confirm).
atlas_file = "ara_nissl_25.nrrd"
# Index along the volume's first axis of the coronal slice matching the MERFISH section.
ATLAS_SLICE = 262

J = nrrd.read(atlas_file)[0]  # nrrd.read returns (data, header)
# Use a length-1 slice rather than an integer index so J keeps a leading
# singleton axis: the same (channels, rows, cols) layout that is used for I.
J = J[ATLAS_SLICE:ATLAS_SLICE + 1, ...]

fig, ax = plt.subplots()
ax.imshow(J.transpose(1, 2, 0))
ax.set_title(f'Atlas Nissl, slice {ATLAS_SLICE}');

<IPython.core.display.Javascript object>

<matplotlib.image.AxesImage at 0x7fff63684fd0>

# Declare corresponding landmark points, set up the image coordinate grids, and plot them

In [5]:
# Hand-picked corresponding landmarks as (row, col) pixel indices. Row k of
# pointsI (MERFISH raster) matches row k of pointsJ (atlas slice). Each set is
# scaled by its image's pixel spacing (30. for I, 25. for J), which puts it in the
# same coordinates as the YI/XI and YJ/XJ grids.
pointsI = 30. * np.array([
    [231, 146], [233, 173], [166, 166], [167, 101], [182, 227],  # 0-4
    [ 36, 175], [126, 104], [141, 231], [112,  55], [121, 282],  # 5-9
    [164, 132], [169, 201], [ 78, 172],                          # 10-12
])
pointsJ = 25. * np.array([
    [ 21, 210], [ 21, 242], [110, 226], [ 93, 156], [ 93, 300],  # 0-4
    [281, 226], [159, 139], [159, 315], [176,  73], [176, 381],  # 5-9
    [109, 201], [109, 255], [220, 226],                          # 10-12
])

In [23]:
# Normalize the intensities of both images. J is cast to I's dtype so the two
# images share a dtype.
I = tools.normalize(M)
J = tools.normalize(J).astype(I.dtype)

# Coordinates of every pixel row (Y*) and column (X*): the pixel index times the
# spacing, which is 30. for the MERFISH raster I and 25. for the atlas slice J.
YI = 30. * np.arange(I.shape[1])
XI = 30. * np.arange(I.shape[2])
YJ = 25. * np.arange(J.shape[1])
XJ = 25. * np.arange(J.shape[2])

# imshow extents derived from those coordinate vectors.
extentI = tools.extent_from_x((YI, XI))
extentJ = tools.extent_from_x((YJ, XJ))

In [8]:
fig, ax = plt.subplots(1, 2)
# Draw both images in physical coordinates (via extent), so the physically
# scaled landmarks land on the right pixels.
ax[0].imshow((I / I.max()).transpose(1, 2, 0), extent=extentI)
ax[1].imshow(J.transpose(1, 2, 0), extent=extentJ)
ax[0].set_title('MERFISH raster I')
ax[1].set_title('Atlas slice J')

# Landmarks are stored as (row, col) = (y, x), so column 1 goes on the x axis.
# Matching landmarks carry the same index label in both panels.
for panel, points in ((ax[0], pointsI), (ax[1], pointsJ)):
    panel.scatter(points[:, 1], points[:, 0])
    for i, (row, col) in enumerate(points):
        panel.text(col, row, f'{i}');

<IPython.core.display.Javascript object>

# Run LDDMM registration, initialized with `L`, `T` estimated from the landmark pairs

In [9]:
# No reload here: `tools` is already imported and reloaded in the imports cell,
# so a second reload is redundant under Restart & Run All.

# Initial transform estimated from the landmark correspondences (presumably the
# linear part L and the translation T; see tools.L_T_from_points).
L, T = tools.L_T_from_points(pointsI, pointsJ)
device = 'cpu'
params = {
    'L': L,
    'T': T,
    'pointsI': pointsI,
    'pointsJ': pointsJ,
    'niter': 1000,
    'device': device,
    'sigmaM': 0.2,  # NOTE(review): document meaning/units; see tools.LDDMM
}

# Register I (on grid [YI, XI]) to J (on grid [YJ, XJ]). The outputs A, v and xv
# are consumed by tools.build_transform and the tools.transform_* helpers below.
A, v, xv = tools.LDDMM([YI, XI], I, [YJ, XJ], J, **params)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# Apply transform to the "atlas" image `I` (MERFISH), resampled onto the target grid of `J`

In [20]:
# Contour levels, in the coordinate units of the grids, for drawing the deformed grid.
levels = np.arange(-100000, 100000, 1000)

fig, ax = plt.subplots()
# direction='b' transform sampled on the target grid (YJ, XJ).
phii = tools.build_transform(xv, v, A, XJ=[YJ, XJ], direction='b')
# Resample image I and its landmarks into the target space. The final
# [YJ, XJ] argument is presumably the output grid.
phiI = tools.transform_image_atlas_to_target(xv, v, A, [YI, XI], I, [YJ, XJ])
phipointsI = tools.transform_points_atlas_to_target(xv, v, A, pointsI)

# Deformed grid: red contours follow component 0 and green contours follow
# component 1 (presumably row/y and col/x, matching the (row, col) landmark layout).
ax.contour(XJ, YJ, phii[..., 0], colors='r', linestyles='-', levels=levels)
ax.contour(XJ, YJ, phii[..., 1], colors='g', linestyles='-', levels=levels)
ax.set_aspect('equal')
ax.set_title('Atlas to target')

# Detach from autograd and move to host memory before handing tensors to
# matplotlib, so plotting works whatever `device` or requires_grad is.
phiI_cpu = phiI.detach().cpu()
ax.imshow(phiI_cpu.permute(1, 2, 0) / torch.max(phiI_cpu), extent=extentJ)
phipointsI_cpu = phipointsI.detach().cpu()
ax.scatter(phipointsI_cpu[:, 1], phipointsI_cpu[:, 0], c="m");

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7fff4f512890>

# Apply transform to the target image `J` (Allen Nissl), resampled onto the grid of `I`

In [24]:
# Contour levels, in the coordinate units of the grids, for drawing the deformed grid.
levels = np.arange(-100000, 100000, 1000)

fig, ax = plt.subplots()
# direction='f' transform sampled on the MERFISH grid (YI, XI).
phi = tools.build_transform(xv, v, A, XJ=[YI, XI], direction='f')
# Resample the target image J and its landmarks into the space of I. The final
# [YI, XI] argument is presumably the output grid.
phiiJ = tools.transform_image_target_to_atlas(xv, v, A, [YJ, XJ], J, [YI, XI])
phiipointsJ = tools.transform_points_target_to_atlas(xv, v, A, pointsJ)

# Deformed grid: red contours follow component 0 and green contours follow
# component 1 (presumably row/y and col/x, matching the (row, col) landmark layout).
ax.contour(XI, YI, phi[..., 0], colors='r', linestyles='-', levels=levels)
ax.contour(XI, YI, phi[..., 1], colors='g', linestyles='-', levels=levels)
ax.set_aspect('equal')
ax.set_title('Target to atlas')

# Detach from autograd and move to host memory before handing tensors to
# matplotlib, so plotting works whatever `device` or requires_grad is.
phiiJ_cpu = phiiJ.detach().cpu()
ax.imshow(phiiJ_cpu.permute(1, 2, 0) / torch.max(phiiJ_cpu), extent=extentI)
phiipointsJ_cpu = phiipointsJ.detach().cpu()
ax.scatter(phiipointsJ_cpu[:, 1], phiipointsJ_cpu[:, 0], c="m");

<IPython.core.display.Javascript object>

<matplotlib.collections.PathCollection at 0x7fff4fa40fd0>