# LDDMM Examples

In this file we run image matching LDDMM on several examples, and demonstrate the utilities in lddmm.py.

The only special functions LDDMM needs to run are

1. Interpolation
1. FFT
1. Gradient



## Library imports
We start by importing necessary libraries.  That includes numpy, matplotlib, and tensorflow for numerical work, nibabel for loading neuroimages, and lddmm and vis which are part of this library.

In [2]:
import numpy as np # for arrays
%matplotlib notebook
import matplotlib as mpl # for graphics
import matplotlib.pyplot as plt
import nibabel as nib # for loading neuroimages
import lddmm # algorithm
import vis # visualization
import tensorflow as tf

Importing helper functions
importing vis


## Development
During development, we have to reimport any libraries we are working on.

In [3]:
# for development
import imp
vis = imp.reload(vis)
lddmm = imp.reload(lddmm)

importing vis
Importing helper functions


## Example data
We will start with some human MRI as example data.  These are two atlases from mricloud.org.  We will deform the atlas to match the target.

In [None]:
# get filenames
atlas_image_fname = 'Adt27-55_02_Adt27-55_02_MNI.img'
target_image_fname = 'Adt27-55_03_Adt27-55_03_MNI.img'

In [None]:
# load them with nibabel
fnames = [atlas_image_fname,target_image_fname]
img = [nib.load(fname) for fname in fnames]

In [None]:
# get info about domains
# we assume for this example that we have the same voxel size and same voxel spacing for atlas and target
if '.img' == atlas_image_fname[-4:]:    
    nx = img[0].header['dim'][1:4]
    dx = img[0].header['pixdim'][1:4]
else:
    # I'm only working with analyze for now
    raise ValueError('Only Analyze images supported for now')
x = [np.arange(nxi)*dxi - np.mean(np.arange(nxi)*dxi) for nxi,dxi in zip(nx,dx)]


In [None]:
# get the images, note they also include a fourth axis for time that I don't want
I = img[0].get_data()[:,:,:,0]
J = img[1].get_data()[:,:,:,0]
# in this example, images are the same size
# this can be enfored after affine registration
# In the future, we can implement different sizes

In [None]:
# display the data
f = plt.figure()
vis.imshow_slices(I,x=x,fig=f)
f.suptitle('Atlas I')
f.canvas.draw()

In [None]:
f = plt.figure()
vis.imshow_slices(J,x=x,fig=f)
f.suptitle('Target J')
f.canvas.draw()

## Interpolation
Deformations are compted by interpolating an image at a set of points.  We demonstrate interpolation here:

In [None]:
do_interp_test = False
do_interp_test = True
if do_interp_test:
    lddmm = imp.reload(lddmm)
    vis = imp.reload(vis)
    # a quick test
    X0,X1,X2 = np.meshgrid(x[0],x[1],x[2],indexing='ij')
    X0tf = tf.constant(X0,dtype=lddmm.dtype)
    X1tf = tf.constant(X1,dtype=lddmm.dtype)
    X2tf = tf.constant(X2,dtype=lddmm.dtype)
    Itf = tf.constant(I,dtype=lddmm.dtype)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        Id = lddmm.interp3(x[0],x[1],x[2],Itf,X0tf+10,X1tf*1.2,X2tf + X2tf**2*0.005)
        Idnp = Id.eval()
    f = plt.figure()
    vis.imshow_slices(Idnp,x=x,fig=f)
    f.suptitle('Example of transforming by interpolation')
    f.canvas.draw()
    

## Gradient
Gradients are required to know how the energy will change when the image moves a small amount.  We demonstrate this here

In [None]:
do_grad_test = False
do_grad_test = True
if do_grad_test:
    # a quick test    
    lddmm = imp.reload(lddmm)
    vis = imp.reload(vis)
    Itf = tf.constant(I,dtype=lddmm.dtype)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        I_0,I_1,I_2 = lddmm.grad3(Itf,dx)
        Ishow = tf.sqrt(I_0**2 + I_1**2 + I_2**2)
        Ishownp = Ishow.eval()
    f = plt.figure()        
    vis.imshow_slices(Ishownp,x=x,fig=f)
    f.canvas.draw()
    

## LDDMM

In [None]:
lddmm = imp.reload(lddmm)
out = lddmm.lddmm(I,J,niter=50,sigmaM=1e1,sigmaR=1e0,epsilon=1e-1,p=2,a=2)

## Example with mouse images


In [None]:
# get filenames
atlas_image_fname = 'PMD2052_orig_target_STS_clean.img'
target_image_fname = 'PMD3097_orig_target_STS_clean.img'

In [None]:
# load them with nababel
fnames = [atlas_image_fname,target_image_fname]
img = [nib.load(fname) for fname in fnames]

In [None]:
# get info about domains
# we assume for this example that we have the same voxel size and same voxel spacing for atlas and target
if '.img' == atlas_image_fname[-4:]:    
    nx = img[0].header['dim'][1:4]
    dx = img[0].header['pixdim'][1:4]
else:
    # I'm only working with analyze for now
    raise ValueError('Only Analyze images supported for now')
x = [np.arange(nxi)*dxi - np.mean(np.arange(nxi)*dxi) for nxi,dxi in zip(nx,dx)]

In [None]:
# get the images, note they also include a fourth axis for time that I don't want
I = img[0].get_data()[:,:,:,0]
J = img[1].get_data()[:,:,:,0]
# in this example, images are the same size
# this can be enfored after affine registration
# In the future, we can implement different sizes

In [None]:
# display the data
f = plt.figure()
vis.imshow_slices(I, x=x, fig=f)
f.suptitle('Atlas I')
f.canvas.draw()

In [None]:
f = plt.figure()
vis.imshow_slices(J,x=x,fig=f)
f.suptitle('Target J')
f.canvas.draw()

In [None]:
lddmm = imp.reload(lddmm)
vis = imp.reload(vis)

p = 2
sigmaM = 10.0
eT = 1e-5
eL = 5e-7
eV = 1e-3
naffine = 50
niter = 200
sigmaR = 2e-1


out = lddmm.lddmm(I, J, 
                  niter=niter, 
                  naffine=naffine,
                  eV = eV,
                  eT = eT,
                  eL = eL,
                  sigmaM=sigmaM, 
                  sigmaR=sigmaR,
                  xI=x,
                  xJ=x,
                  a=(x[0][1]-x[0][0])*2,
                  p=p)

## Example with mouse images and outlier voxels

In [62]:
# get filenames
atlas_image_fname = 'FluoroAtlas_Downsample.img'
atlas_image_fname = 'average_template_50.img'
target_image_fname = '180517_ch1_Downsample.img'

In [63]:
# load them with nababel
fnames = [atlas_image_fname,target_image_fname]
img = [nib.load(fname) for fname in fnames]

In [64]:
# get info about domains
# we assume for this example that we have the same voxel size and same voxel spacing for atlas and target
if '.img' == atlas_image_fname[-4:]:    
    nxI = img[0].header['dim'][1:4]
    dxI = img[0].header['pixdim'][1:4]
    nxJ = img[1].header['dim'][1:4]
    dxJ = img[1].header['pixdim'][1:4]
    # dx's are wrong in headers for this example
    
    #dxJ = [0.025,0.025,0.025];
    dxJ = [0.05,0.05,0.05]
else:
    # I'm only working with analyze for now
    raise ValueError('Only Analyze images supported for now')
xI = [np.arange(nxi)*dxi - np.mean(np.arange(nxi)*dxi) for nxi,dxi in zip(nxI,dxI)]
xJ = [np.arange(nxi)*dxi - np.mean(np.arange(nxi)*dxi) for nxi,dxi in zip(nxJ,dxJ)]


# NOTE that dx here is 1 because header data is incorrect, TO DO update headers

In [65]:
# get the images, note they also include a fourth axis for time that I don't want
I = img[0].get_data()[:,:,:,0]
J = img[1].get_data()[:,:,:,0]
# in this example, images are the same size
# this can be enfored after affine registration
# In the future, we can implement different sizes



In [None]:
# I would like to pad one slice of the allen atlas so that it has zero boundary conditions

nxI = 
xI = [np.arange(nxi)*dxi - np.mean(np.arange(nxi)*dxi) for nxi,dxi in zip(nxI,dxI)]

In [66]:
# display the data
f = plt.figure()
vis.imshow_slices(I, x=xI, fig=f)
f.suptitle('Atlas I')
f.canvas.draw()

<IPython.core.display.Javascript object>

In [28]:
f = plt.figure()
vis.imshow_slices(J,x=xJ,fig=f)
f.suptitle('Target J')
f.canvas.draw()

<IPython.core.display.Javascript object>

In [67]:
# let's quickly normalize the data
#I = I - np.mean(I)
#I = I / np.std(I)
#
#J = J - np.mean(J)
#J = J / np.std(J)

J = J - 0.0 # make it not be a memmap
Ibar = np.mean(I)
I0 = I - Ibar
Jbar = np.mean(J)
J0 = J - Jbar
VarI = np.mean(I0**2)
#CovIJ = np.mean(I0*J0) # okay this doesn't work because they are different sizes
#Iin = Iin/VarI*CovIJ
#Iin = Iin + np.mean(J) - np.mean(Iin)
VarJ = np.mean(J0**2)
Iin = I * np.sqrt(VarJ / VarI)
Iin = Iin + np.mean(J) - np.mean(Iin)

In [75]:
A = np.eye(4);
A = np.array([[1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]]).dot(A)
A = np.array([[0,0,1,0],[0,1,0,0],[1,0,0,0],[0,0,0,1]]).dot(A)
A = np.diag([-1,1,1,1]).dot(A)
# the line below is a good initial guess
A = np.array([[0,0,1,0],[-1,0,0,0],[0,1,0,0],[0,0,0,1]])
A

array([[ 0,  0,  1,  0],
       [-1,  0,  0,  0],
       [ 0,  1,  0,  0],
       [ 0,  0,  0,  1]])

In [76]:
# test the initial affine
X0,X1,X2 = np.meshgrid(xJ[0],xJ[1],xJ[2],indexing='ij')
X0tf = tf.constant(X0,dtype=lddmm.dtype)
X1tf = tf.constant(X1,dtype=lddmm.dtype)
X2tf = tf.constant(X2,dtype=lddmm.dtype)
Itf = tf.constant(I,dtype=lddmm.dtype)
B = np.linalg.inv(A)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    Xs = B[0,0]*X0tf + B[0,1]*X1tf + B[0,2]*X2tf + B[0,3]
    Ys = B[1,0]*X0tf + B[1,1]*X1tf + B[1,2]*X2tf + B[1,3]
    Zs = B[2,0]*X0tf + B[2,1]*X1tf + B[2,2]*X2tf + B[2,3]
    Id = lddmm.interp3(xI[0], xI[1], xI[2], Itf, Xs, Ys, Zs)
    Idnp = Id.eval()
f = plt.figure()
vis.imshow_slices(Idnp,x=xJ,fig=f)
f.suptitle('Initial affine transformation')
f.canvas.draw()


<IPython.core.display.Javascript object>

In [None]:
lddmm = imp.reload(lddmm)
vis = imp.reload(vis)

p = 2
sigmaM = np.std(J)
eT = 1e-5
eL = 5e-7
eV = 1e-3
naffine = 50
niter = 200
sigmaR = 2e-1

# these numbers from my matlab code
eL = 2e-4
eT = 1e-3
eV = 5e-3
sigmaR = 1e0
niter = 100
naffine = 50
nt = 5
# the linear part is a bit too big still



out = lddmm.lddmm(Iin, J, 
                  niter=niter, 
                  naffine=naffine,
                  eV = eV,
                  eT = eT,
                  eL = eL,
                  nt=nt,
                  sigmaM=sigmaM, 
                  sigmaR=sigmaR,
                  xI=xI,
                  xJ=xJ,
                  a=(xI[0][1]-xI[0][0])*5,
                  p=p,
                  A0=A)

Importing helper functions
importing vis
Set default parameters
initial affine transform [[ 0  0  1  0]
 [-1  0  0  0]
 [ 0  1  0  0]
 [ 0  0  0  1]]
Got parameters


<IPython.core.display.Javascript object>

Built energy operators
built tensorflow variables
Computation graph defined


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

Finished iteration 0, energy 5.237021e+02 (match 5.237021e+02, reg 0.000000e+00)
Finished iteration 1, energy 4.476383e+02 (match 4.476383e+02, reg 0.000000e+00)
Finished iteration 2, energy 4.085172e+02 (match 4.085172e+02, reg 0.000000e+00)
Finished iteration 3, energy 3.803023e+02 (match 3.803023e+02, reg 0.000000e+00)
Finished iteration 4, energy 3.607551e+02 (match 3.607551e+02, reg 0.000000e+00)
Finished iteration 5, energy 3.447175e+02 (match 3.447175e+02, reg 0.000000e+00)
Finished iteration 6, energy 3.300470e+02 (match 3.300470e+02, reg 0.000000e+00)
Finished iteration 7, energy 3.163517e+02 (match 3.163517e+02, reg 0.000000e+00)
Finished iteration 8, energy 3.037829e+02 (match 3.037829e+02, reg 0.000000e+00)
Finished iteration 9, energy 2.914098e+02 (match 2.914098e+02, reg 0.000000e+00)
Finished iteration 10, energy 2.787104e+02 (match 2.787104e+02, reg 0.000000e+00)
Finished iteration 11, energy 2.662872e+02 (match 2.662872e+02, reg 0.000000e+00)
Finished iteration 12, ene

In [34]:
out[0]

array([[-6.7742080e-03, -9.9624670e-01,  2.4537861e-04, -1.1723356e-02],
       [ 1.7252788e-03,  2.8352300e-03,  9.7003663e-01, -1.0234620e-03],
       [ 1.0080583e+00,  7.4772318e-03,  3.8079582e-03, -6.4482382e-03],
       [ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00,  1.0000000e+00]],
      dtype=float32)