# Motion-Corrected Image Reconstruction (MCIR)

In [None]:
# Setup the working directory for the notebook
import notebook_setup

In [None]:
# Version string of this notebook module
__version__ = '0.1.0'

# import engine module
import sirf.Gadgetron as pMR
import sirf.Reg as pReg

from cil.framework import  AcquisitionGeometry, BlockDataContainer
from cil.optimisation.functions import Function, L2NormSquared, BlockFunction, LeastSquares, ZeroFunction
from cil.optimisation.operators import BlockOperator, CompositionOperator, LinearOperator
from cil.optimisation.algorithms import FISTA
from sirf.Reg import NiftyResampler 

# import further modules
import os
import numpy as np
import scipy.signal as sp_signal

In [None]:
# Load in the data: raw MR acquisitions read from an .h5 file.
# The path is hard-coded; adjust it to where the data lives on your system.
path = '/home/jovyan/devel/MR/RPE_MotionPhantom_first70rpe.h5'
acq_data = pMR.AcquisitionData(path)
# Sort the acquisitions by time (SIRF sort_by_time)
acq_data.sort_by_time()

## Standard (i.e. uncorrected) image reconstruction

In [None]:
# Calculate coil sensitivity maps from the acquired data.
# `smoothness` is a CoilSensitivityData setting (set to 100 here);
# NOTE(review): see the SIRF documentation for its exact effect.
csm = pMR.CoilSensitivityData()
csm.smoothness = 100
csm.calculate(acq_data)

# Motion

## Motion surrogate

In [None]:
# Phase-encoding index (kspace_encode_step_1) of every acquisition
pe_ky = acq_data.get_ISMRMRD_info('kspace_encode_step_1')

# Find the central value of each GRPE line (i.e. ky=0): the acquisitions whose
# phase-encoding index equals (max + 1) // 2. np.where returns a tuple; its
# first element holds the matching acquisition indices in increasing order.
ky_idx = np.where(pe_ky == (np.max(pe_ky) + 1) // 2)

# Get the k-space as array and keep only the acquisitions taken at the
# k-space centre (the first axis indexes acquisitions)
acq_data_arr = acq_data.as_array()
acq_data_arr = acq_data_arr[ky_idx[0], :, :]

# Coil used for the self-navigator. This is a fixed index (the 4th coil),
# not necessarily the last coil. Axis 1 is assumed to index coils and axis 2
# the readout -- NOTE(review): confirm against the SIRF array layout.
nav_coil = 3

# Magnitude at the centre of each readout (kx = 0) of the selected coil
# gives the final 1D motion signal
self_nav = np.abs(np.squeeze(acq_data_arr[:, nav_coil, acq_data_arr.shape[2] // 2]))

# Ideally the k-space centre varies only because of motion, but other effects
# also change it. One is that the spin system is in a transient steady state
# at the beginning of the acquisition, so the first sample is overwritten
# with its neighbour.
self_nav[0] = self_nav[1]

# Median filter (kernel size 7) to suppress outliers
self_nav = sp_signal.medfilt(self_nav, 7)

# Interpolate the navigator from the k-space-centre acquisitions to every
# acquisition index 0 .. len(pe_ky) - 1
self_nav = np.interp(np.linspace(0, len(pe_ky) - 1, len(pe_ky)), ky_idx[0], self_nav)

# Acquisition indices ordered by navigator amplitude
nav_idx = np.argsort(self_nav)

# Bin the sorted acquisitions into Nms motion states: the first Nms - 1 bins
# hold num_pe_per_ms acquisitions each, the last bin takes whatever remains
# (possibly fewer).
Nms = 4
num_pe_per_ms = np.ceil(len(pe_ky) / Nms).astype(np.int64)
acq_idx_ms = [nav_idx[nnd * num_pe_per_ms:(nnd + 1) * num_pe_per_ms] for nnd in range(Nms - 1)]
acq_idx_ms.append(nav_idx[(Nms - 1) * num_pe_per_ms:])



## Motion Resolved Images

In [None]:
# For every motion state build its k-space subset (acq_ms[ind]) and an
# acquisition model (E_ms[ind]) that uses the coil sensitivity maps csm.
acq_ms = []
E_ms = []

for ind in range(Nms):
    # k-space of this motion state, sorted by time
    acq_subset = acq_data.get_subset(acq_idx_ms[ind])
    acq_subset.sort_by_time()
    acq_ms.append(acq_subset)

    # A temporary model built on csm yields, via inverse(), an image im_ms
    # for this subset; im_ms is then used as the image template of the
    # final acquisition model.
    E_tmp = pMR.AcquisitionModel(acqs=acq_subset, imgs=csm)
    E_tmp.set_coil_sensitivity_maps(csm)
    im_ms = E_tmp.inverse(acq_subset)

    model = pMR.AcquisitionModel(acqs=acq_subset, imgs=im_ms)
    model.set_coil_sensitivity_maps(csm)
    E_ms.append(model)

Now we can reconstruct each motion state:

In [None]:
# Reconstruct every motion state separately with FISTA (least-squares data
# term, ZeroFunction as g). im_fista_ms[ind] receives the reconstruction of
# motion state ind.
im_fista_ms = [0] * Nms

for ind in range(Nms):

    # Starting image: zero-filled clone of im_ms, i.e. the image produced for
    # the last motion state in the previous cell. NOTE(review): assumes all
    # motion states share this image geometry.
    x_init = im_ms.clone()
    x_init.fill(0.0)

    # Objective function: LeastSquares on this motion state's acquisition
    # model and k-space data (c=1), without regularisation
    f = LeastSquares(E_ms[ind], acq_ms[ind], c=1)
    G = ZeroFunction()

    # Set up FISTA for least squares; update_objective_interval controls how
    # often CIL evaluates the objective
    fista = FISTA(initial=x_init, f=f, g=G)
    fista.max_iteration = 100
    fista.update_objective_interval = 5

    # Run FISTA: run(10, ...) requests 10 iterations, fewer than the
    # max_iteration set above
    fista.run(10, verbose=True)
    
    # Get result
    im_fista_ms[ind] = fista.get_output()

## Estimate Motion Vector fields

In [None]:
# Magnitude of each motion-state reconstruction (input to the registration)
im_fista_ms_abs = [im_fista_ms[k].abs() for k in range(Nms)]

In [None]:
# Motion transformation objects, one per motion state.
# For motion state ind, the reference image is state ind and the floating
# image is state 0, so mf_resampler[ind] transforms an image in motion
# state 0 to motion state ind. Registration uses the magnitude images.
mf_resampler = []

for ind in range(Nms):

    # Affine image registration
    algo = pReg.NiftyAladinSym()

    # Set up images: reference = current motion state, floating = state 0
    ref = pReg.NiftiImageData3D(im_fista_ms_abs[ind])
    flo = pReg.NiftiImageData3D(im_fista_ms_abs[0])
    algo.set_reference_image(ref)
    algo.set_floating_image(flo)

    # Run registration
    algo.process()

    # Forward deformation field estimated by the registration
    mf_forward = algo.get_deformation_field_forward()

    # Resampler applying that deformation with linear interpolation and a
    # padding value of 0
    resampler = NiftyResampler()
    resampler.set_reference_image(ref)
    resampler.set_floating_image(flo)
    resampler.add_transformation(mf_forward)
    resampler.set_padding_value(0)
    resampler.set_interpolation_type_to_linear()
    mf_resampler.append(resampler)


# MCIR norms

In [None]:
# Operator norm of each resampler, i.e. the largest singular value of its
# direct operator. Two estimates are computed in this notebook:
#   (i)  SIRF's built-in NiftyResampler.norm, a C++ Jacobi-Conjugate Gradient
#        implementation exposed to Python (this cell);
#   (ii) a Python implementation of the power method (further below).

import time


# (i) built-in estimate, called with arguments (10, 0)
print("computing nifti resampler norms by CG...")
nifti_resamplers_norms = []
start = time.time()
for i, resampler_i in enumerate(mf_resampler):
    rni = resampler_i.norm(10, 0)
    print("resampler %d norm: %f" % (i, rni))
    nifti_resamplers_norms.append(rni)
elapsed = time.time() - start
print('%f sec' % elapsed)


# (ii)
class MyLinearOperator(LinearOperator):
    """LinearOperator whose norm is estimated by a custom power method.

    ``PowerMethod`` redefines the power method (CIL's version was giving
    issues for computing the norm). It only needs ``operator`` to provide
    ``domain_geometry()``, ``range_geometry()``, ``direct(x, out=...)`` and
    ``adjoint(y, out=...)``. It can therefore also be applied to objects that
    are not ``MyLinearOperator`` instances, e.g. ``NiftyResampler``.
    """

    @staticmethod
    def PowerMethod(operator, max_iteration=10, initial=None, tolerance=1e-5, return_all=False):
        r"""Estimate the operator norm ||A|| (largest singular value) of ``operator``.

        Every iteration applies A^* A to the current unit-norm vector. The
        norm of the result therefore approximates the largest eigenvalue of
        A^* A, which is ||A||**2, and the estimate returned is its square
        root. This holds whether or not the domain and range geometries are
        equal: equal geometries do not imply a self-adjoint operator.

        Parameters
        ----------
        operator : object
            Provides domain_geometry, range_geometry, direct and adjoint.
        max_iteration : int
            Maximum number of iterations; must be at least 1.
        initial : optional
            Starting vector in the domain (copied, not modified). Defaults
            to a random vector allocated from the domain geometry.
        tolerance : float
            Iteration stops once two successive estimates differ by at
            most this amount.
        return_all : bool
            If True, also return the number of iterations, the final
            vector and the list of all estimates.

        Returns
        -------
        The norm estimate, or ``(estimate, iterations, x, estimates)`` when
        ``return_all`` is True.

        Raises
        ------
        ValueError
            If ``max_iteration`` is smaller than 1.
        """
        if max_iteration < 1:
            raise ValueError("max_iteration must be at least 1")

        if initial is None:
            x0 = operator.domain_geometry().allocate('random')
        else:
            x0 = initial.copy()

        y_tmp = operator.range_geometry().allocate()

        # Normalise the starting vector
        x0_norm = x0.norm()
        x0 /= x0_norm

        eig_old = 1.
        eig_list = []
        diff = np.finfo('d').max

        i = 0
        while i < max_iteration and diff > tolerance:
            i += 1
            operator.direct(x0, out=y_tmp)
            operator.adjoint(y_tmp, out=x0)

            x0_norm = x0.norm()
            if x0_norm == 0:
                # A^* A annihilated the iterate: the estimate is 0 and the
                # vector cannot be normalised any further.
                eig_new = 0.
                eig_list.append(eig_new)
                break
            x0 /= x0_norm
            # ||A^* A x|| with ||x|| = 1 approximates ||A||**2
            eig_new = np.sqrt(np.abs(x0_norm))

            eig_list.append(eig_new)
            # Convergence check against the previous estimate
            diff = np.abs(eig_new - eig_old)
            eig_old = eig_new

        if return_all:
            return eig_new, i, x0, eig_list
        return eig_new

    def calculate_norm(self):
        r"""Return the norm of the operator estimated by ``PowerMethod`` with default settings."""
        return MyLinearOperator.PowerMethod(self)


def NiftyResampler_norm(x):
    """Return the norm of resampler ``x`` as estimated by
    ``MyLinearOperator.calculate_norm`` (Python power method)."""
    return MyLinearOperator.calculate_norm(x)


# Replace NiftyResampler.norm so that every later resampler.norm() call
# returns the Python power-method estimate (NiftyResampler_norm)
NiftyResampler.norm = NiftyResampler_norm


# (ii) power-method estimate of each resampler norm
print("computing nifti resampler norms by power method...")
nifti_resamplers_norms = []
start = time.time()
for i, resampler_i in enumerate(mf_resampler):
    rni = resampler_i.norm()
    print("resampler %d norm: %f" % (i, rni))
    nifti_resamplers_norms.append(rni)
elapsed = time.time() - start
print('%f sec' % elapsed)


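In MCIR, the resampler is one factor of a composition operator, so its norm is computed whenever the following function is called to obtain the norm of the composition operator.
The function multiplies the norms of the individual factors.
This gives an upper bound on the norm of the composition, not necessarily its exact value.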

In [None]:
def CompositionOperator_norm(comp_operator):
    """Return the product of the norms of the operators in ``comp_operator.operators``.

    Operator norms are submultiplicative (||A o B|| <= ||A|| * ||B||), so this
    is an upper bound on the norm of the composition, not necessarily its
    exact value. An empty composition gives 1.0.
    """
    result = 1.
    for op in comp_operator.operators:
        result *= op.norm()
    return result

# Make CompositionOperator.norm() use the product-of-factor-norms estimate
CompositionOperator.norm = CompositionOperator_norm

In [None]:
# One composition operator per motion state, pairing acquisition model
# E_ms[k] with resampler mf_resampler[k]. CompositionOperator(am, res)
# applies its last operator first, i.e. C[k] resamples and then applies
# the acquisition model.
C = [CompositionOperator(am, res) for am, res in zip(E_ms, mf_resampler)]

## Norm values 

Because we use 4 motion states (Nms = 4), there are 4 NiftyResampler norms and therefore 4 composition operator norms.

In [None]:
# nifti resampler norms - already computed above
#nifti_resamplers_norms = []
#for resampler_i in mf_resampler:
#    nifti_resamplers_norms.append(resampler_i.norm())

In [None]:
# Display the resampler norms (last expression of a notebook cell)
nifti_resamplers_norms

In [None]:
# composition operators norm
comp_operators_norms = []
for comp_operator_i in C:
    comp_operators_norms.append(comp_operator_i.norm())

In [None]:
# Display the composition-operator norms (last expression of a notebook cell)
comp_operators_norms

In [None]:
### TODO: THE REST SHOULD BE UPDATED!

## Norm values that you should get

NiftyResampler norms I get using the code above:

[0.9999986886978149,
 1.0106359720230103,
 0.9894781112670898,
 1.0005484819412231]

The values in the list nifti_resamplers_norms (which you get above, in the section "Norm values") should be similar to these ones, i.e. very close to 1.

Composition operator norms I get by using the code above:
    
[24.53206259506578, 25.586999705435574, 24.64181221150693, 24.65180460950478]

Like before, the values in the list comp_operators_norms (again in the section "Norm values") should also be similar to the ones listed above. The tests below require them to agree to one decimal place (assertAlmostEqual with places=1). Hopefully the values will pass the tests when you use your function.

# Tests

In [None]:
import unittest

correct_resampler_norms = [0.9999989867210388, 1.0106064081192017, 0.9894943237304688, 1.0005548000335693]
correct_comp_op_norms = [24.53206844396982, 25.586978579076458, 24.641785492754025, 24.652982414194184]


class TestNorms(unittest.TestCase):
    """Check computed operator norms against reference values.

    The module-level lists ``nifti_resamplers_norms`` and
    ``comp_operators_norms`` (computed earlier in the notebook) are compared
    with ``correct_resampler_norms`` and ``correct_comp_op_norms``, to
    ``decimal_place`` decimal places.

    Each test method takes an optional index. With an index, only that entry
    is checked; this is how the notebook calls them. Without one, every entry
    is checked and the two lists must have the same length. This is the case
    when a unittest/pytest runner calls the method with no arguments.
    """

    # Number of decimal places passed to assertAlmostEqual
    decimal_place = 1

    def _assert_norms_close(self, computed, reference, index):
        """Assert computed[index] ~= reference[index], or every entry if index is None."""
        if index is None:
            self.assertEqual(len(computed), len(reference))
            indices = range(len(reference))
        else:
            indices = [index]
        for idx in indices:
            with self.subTest(index=idx):
                self.assertAlmostEqual(computed[idx], reference[idx], places=self.decimal_place)

    def test_resampler_norm(self, n_resampler=None):
        """Resampler norms must match correct_resampler_norms."""
        self._assert_norms_close(nifti_resamplers_norms, correct_resampler_norms, n_resampler)

    def test_comp_op_norm(self, n_comp_op=None):
        """Composition-operator norms must match correct_comp_op_norms."""
        self._assert_norms_close(comp_operators_norms, correct_comp_op_norms, n_comp_op)

In [None]:
# Run both norm checks by hand for every motion state
example = TestNorms()

for motion_state in range(Nms):
    example.test_resampler_norm(motion_state)
    example.test_comp_op_norm(motion_state)