# Demonstration of MR reconstruction with CCP PET-MR Software

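This demonstration shows how to handle undersampled data
and how to write a simple iterative reconstruction algorithm with
the acquisition model. It also sorts the data into motion states using a
self-navigator and combines the motion-resolved reconstructions via image
registration into a motion-corrected reconstruction.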

This demo is a 'script', i.e. intended to be run step by step in a
Python notebook such as Jupyter. It is organised in 'cells'. Jupyter displays these
cells nicely and allows you to run each cell on its own.

First version: 27th of March 2019
Author: Johannes Mayer

CCP PETMR Synergistic Image Reconstruction Framework (SIRF).  
Copyright 2015 - 2017 Rutherford Appleton Laboratory STFC.  
Copyright 2015 - 2017 University College London.  
Copyright 2015 - 2017 Physikalisch-Technische Bundesanstalt.

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
#%% make sure figures appear inline and animations work
%matplotlib notebook

# Set up the working directory for the notebook
import notebook_setup

In [None]:
__version__ = '0.1.0'

# import engine module
import sirf.Gadgetron as pMR
from sirf.Utilities import examples_data_path
from sirf_exercises import exercises_data_path


from cil.framework import  AcquisitionGeometry, BlockDataContainer, BlockGeometry
from cil.optimisation.functions import Function, OperatorCompositionFunction, SmoothMixedL21Norm, L1Norm, L2NormSquared, BlockFunction, MixedL21Norm, IndicatorBox, TotalVariation, LeastSquares, ZeroFunction
from cil.optimisation.operators import GradientOperator, BlockOperator, ZeroOperator, CompositionOperator,LinearOperator
from cil.optimisation.algorithms import PDHG, FISTA, GD
from cil.plugins.ccpi_regularisation.functions import FGP_TV

# import further modules
import os
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.animation as animation


In [None]:
# Folder containing the raw data, the full data set and the reduced data set
# (only the latter is loaded below)
pname, fname, fname_new = (
    '/media/sf_CCP/mcir_phantom/SIRF/',
    'RPE_MotionPhantom.h5',
    'RPE_MotionPhantom_first70rpe.h5',
)


In [None]:
# Load the acquisition data (kept in memory) and sort the readouts by
# acquisition time.
pMR.AcquisitionData.set_storage_scheme('memory')

data_file = pname + fname_new
acq_data = pMR.AcquisitionData(data_file)
# Optional preprocessing / trajectory set-up, currently disabled:
# acq_data = pMR.preprocess_acquisition_data(acq_data)
# acq_data = pMR.set_grpe_trajectory(acq_data)
acq_data.sort_by_time()

# Optional k-space density compensation, currently disabled:
# kdcf = pMR.compute_kspace_density(acq_data)

In [None]:
# Phase-encoding index (ISMRMRD kspace_encode_step_1) of every readout; used
# below to find the readouts through the k-space centre for self-navigation.
pe_ky = acq_data.get_ISMRMRD_info('kspace_encode_step_1')
#pe_kz = acq_data.get_ISMRMRD_info('kspace_encode_step_2')

In [None]:
import scipy.signal as sp_signal

# Self-navigation: the k-space-centre signal is extracted for every readout
# and used to bin the data into Nms motion states.

# Indices of the readouts acquired on the central ky line
# (assumes ky runs from 0 to max with the centre at (max + 1) // 2 -- TODO confirm)
ky_idx = np.where(pe_ky == (np.max(pe_ky)+1)//2)

# Get k-space as array
acq_data_arr = acq_data.as_array()

print(acq_data_arr.shape)

# Keep only the readouts acquired on the central ky line
acq_data_arr = acq_data_arr[ky_idx[0], :, :]

# Navigator: magnitude of sample 64 of channel 3 of each central readout.
# NOTE(review): assumes axes are (readout, coil, sample) and sample 64 is kx == 0.
self_nav = np.abs(np.squeeze(acq_data_arr[:,3,64]))
# Replace the first navigator value by its neighbour (presumably an outlier)
self_nav[0] = self_nav[1]
self_nav = sp_signal.medfilt(self_nav, 7)

# Interpolate self navigator to all PE numbers
self_nav = np.interp(np.linspace(0, len(pe_ky)-1, len(pe_ky)), ky_idx[0], self_nav)

# Sort navigator and obtain index
nav_idx = np.argsort(self_nav)

# False: every motion state gets (almost) the same number of readouts.
# True: the navigator range is split into Nms bins of equal width.
flag_amp_gating = False

# Bin data into Nms motion states
Nms = 4
num_pe_per_ms = int(np.ceil(len(pe_ky) / Nms))
# Full navigator range, needed for amplitude gating
motion_amplitude = np.max(self_nav) - np.min(self_nav)
acq_idx_ms = []

for nnd in range(Nms):
    if flag_amp_gating:
        ms_begin = nnd * motion_amplitude/Nms + np.min(self_nav)
        ms_end = ms_begin + motion_amplitude/Nms

        if nnd < Nms - 1:
            # Half-open bins so neighbouring states do not share readouts
            cidx = np.where((self_nav >= ms_begin) & (self_nav < ms_end))
        else:
            # Last bin is open-ended so the navigator maximum is always
            # included, even if ms_end is slightly below it due to rounding
            cidx = np.where(self_nav >= ms_begin)

        # cidx already holds readout indices (positions in self_nav)
        acq_idx_ms.append(cidx[0])
    else:
        # Consecutive chunks of the navigator-sorted readouts
        if nnd < Nms - 1:
            acq_idx_ms.append(nav_idx[nnd*num_pe_per_ms:(nnd+1)*num_pe_per_ms])
        else:
            acq_idx_ms.append(nav_idx[nnd*num_pe_per_ms:])


In [None]:
# Plot the navigator trace and overlay the readouts of each motion state as
# markers; also print how many readouts each state received.
plt.figure()
plt.plot(self_nav, '-k')
for ind in range(Nms):
    state_idx = acq_idx_ms[ind]
    print(ind, ' - ', len(state_idx))
    plt.plot(state_idx, self_nav[state_idx], 'o')

In [None]:
# Estimate coil sensitivity maps from all acquired data (no motion gating)
csm = pMR.CoilSensitivityData()
# Smoothing parameter of the estimation; must be set before calculate()
csm.smoothness = 100
csm.calculate(acq_data)


In [None]:
# Show |csm| for index 2 of the first axis in three orthogonal slices
# (NOTE(review): presumably the first axis indexes coils -- confirm)
csm_arr = csm.as_array()
fig, ax = plt.subplots(1, 3)
csm_views = (
    csm_arr[2, 102, :, :],
    csm_arr[2, :, 64, :],
    csm_arr[2, :, :, 64],
)
for panel, view in zip(ax, csm_views):
    panel.imshow(np.abs(view))

In [None]:
# Go through motion states, create the corresponding k-space subsets and
# reconstruct one image per motion state.

# Readout indices per motion state. A pre-computed binning can be used
# instead, e.g.:
# acq_idx_sel = np.load(pname + 'resp_idx_mcir.npy', allow_pickle=True)
acq_idx_sel = acq_idx_ms

# Apply kdcf
#acq_data *= kdcf

# Size all per-state containers from the selected binning (not from Nms) so a
# binning with a different number of states does not cause an IndexError.
num_ms = len(acq_idx_sel)
acq_ms = [0] * num_ms
im_ms = [0] * num_ms
E_ms = [0] * num_ms

# False: build each subset by appending the selected readouts one by one.
# True: use AcquisitionData.get_subset instead.
use_get_subset = False

# squeeze=False keeps ax 2-D even for a single motion state
fig, ax = plt.subplots(3, num_ms, squeeze=False)
plt.setp(ax, xticks=[], yticks=[])
for ind in range(num_ms):

    if use_get_subset:
        acq_ms[ind] = acq_data.get_subset(acq_idx_sel[ind])
    else:
        acq_ms[ind] = acq_data.new_acquisition_data(empty=True)

        # Add motion resolved data
        for acq_idx in acq_idx_sel[ind]:
            acq_ms[ind].append_acquisition(acq_data.acquisition(acq_idx))

    acq_ms[ind].sort_by_time()

    # Temporary acquisition model (coil maps as image template), only used for
    # the initial pseudo-inverse reconstruction of this state
    E_tmp = pMR.AcquisitionModel(acqs=acq_ms[ind], imgs=csm)
    E_tmp.set_coil_sensitivity_maps(csm)

    #im_ms[ind] = E_tmp.adjoint(acq_ms[ind])
    im_ms[ind] = E_tmp.inverse(acq_ms[ind])

    # Acquisition model of this state with the reconstructed image as template
    E_ms[ind] = pMR.AcquisitionModel(acqs=acq_ms[ind], imgs=im_ms[ind])
    E_ms[ind].set_coil_sensitivity_maps(csm)

    # Three orthogonal slices per state; the white line at x = 32 is a fixed
    # visual reference for comparing the states
    rec_im_arr = im_ms[ind].as_array()
    ax[0, ind].imshow(np.abs(rec_im_arr[102, :, :]))
    ax[0, ind].plot([32, 32], [0, 130], '-w')
    ax[1, ind].imshow(np.abs(rec_im_arr[:, 64, :]))
    ax[1, ind].plot([32, 32], [0, 200], '-w')
    ax[2, ind].imshow(np.abs(rec_im_arr[:, :, 50]))


In [None]:
# Create acquisition model for all acquired data (no motion gating), using the
# coil sensitivity maps as image template
E = pMR.AcquisitionModel(acqs=acq_data, imgs=csm)
E.set_coil_sensitivity_maps(csm)

# Pseudo-inverse
rec_im = E.inverse(acq_data)

In [None]:
# Magnitude of the pseudo-inverse reconstruction in three orthogonal slices
rec_im_arr = rec_im.as_array()

fig, ax = plt.subplots(1, 3)
slices = (rec_im_arr[102, :, :], rec_im_arr[:, 64, :], rec_im_arr[:, :, 64])
for panel, img_slice in zip(ax, slices):
    panel.imshow(np.abs(img_slice))

In [None]:
import time

# Least-squares FISTA reconstruction of all acquired data (no motion gating),
# initialised with the pseudo-inverse rec_im.
E = pMR.AcquisitionModel(acqs=acq_data, imgs=rec_im)
E.set_coil_sensitivity_maps(csm)


num_it_fista = 20
x_init = rec_im.clone()

t1 = time.time()
f = LeastSquares(E, acq_data, c=1)
print('LS {:3.2f}s'.format((time.time() - t1)))

# No regularisation
G = ZeroFunction()

# alpha = 0.01
# G = alpha * FGP_TV(max_iteration=10, device='cpu')

# Run FISTA for least squares
t1 = time.time()
fista = FISTA(x_init=x_init, f=f, g=G)
fista.max_iteration = num_it_fista
fista.update_objective_interval = 1
print('SETUP {:3.2f}s'.format((time.time() - t1)))

t1 = time.time()
# Request exactly num_it_fista iterations so the call agrees with
# max_iteration (a larger request would be capped by it anyway)
fista.run(num_it_fista, verbose=True)
print('FISTA {:3.2f}s'.format((time.time() - t1)))

In [None]:
# Magnitude of the FISTA reconstruction in three orthogonal slices
rec_im_arr = fista.get_output().as_array()

fig, ax = plt.subplots(1, 3)
for panel, img_slice in zip(ax, (rec_im_arr[102, :, :],
                                 rec_im_arr[:, 64, :],
                                 rec_im_arr[:, :, 64])):
    panel.imshow(np.abs(img_slice))

In [None]:
# Least-squares FISTA reconstruction of every motion state separately,
# each initialised with that state's pseudo-inverse im_ms[ind].
rec_ms_fista = [0] * num_ms
# Same iteration count for every state (loop-invariant)
num_it_fista = 10

for ind in range(num_ms):

    x_init = im_ms[ind].clone()

    t1 = time.time()
    f = LeastSquares(E_ms[ind], acq_ms[ind], c=1)
    print('LS {:3.2f}s'.format((time.time() - t1)))

    # No regularisation
    G = ZeroFunction()

    # alpha = 0.01
    # G = alpha * FGP_TV(max_iteration=10, device='cpu')

    # Run FISTA for least squares
    t1 = time.time()
    fista = FISTA(x_init=x_init, f=f, g=G)
    fista.max_iteration = num_it_fista
    fista.update_objective_interval = 1
    print('SETUP {:3.2f}s'.format((time.time() - t1)))

    t1 = time.time()
    # Request exactly num_it_fista iterations so the call agrees with
    # max_iteration (a larger request would be capped by it anyway)
    fista.run(num_it_fista, verbose=True)
    print('FISTA {:3.2f}s'.format((time.time() - t1)))

    rec_ms_fista[ind] = fista.get_output()
    


In [None]:

# Per-state FISTA reconstructions: one column per motion state, rows are three
# orthogonal slices. The white line at x = 32 is a fixed visual reference for
# comparing the states. squeeze=False keeps ax 2-D even for a single state.
fig, ax = plt.subplots(3, num_ms, squeeze=False)
plt.setp(ax, xticks=[], yticks=[])
for ind in range(num_ms):

    rec_im_arr = rec_ms_fista[ind].as_array()
    ax[0, ind].imshow(np.abs(rec_im_arr[102, :, :]))
    ax[0, ind].plot([32, 32], [0, 130], '-w')
    ax[1, ind].imshow(np.abs(rec_im_arr[:, 64, :]))
    ax[1, ind].plot([32, 32], [0, 200], '-w')
    ax[2, ind].imshow(np.abs(rec_im_arr[:, :, 50]))

In [None]:
# Magnitude images of the per-state FISTA reconstructions (inputs to the
# registration below)
im_ms_rec = [rec_ms_fista[ind].abs() for ind in range(num_ms)]

In [None]:
# Register different motion gates: each motion state (reference) is registered
# to motion state 0 (floating), and a resampler per state is built from the
# resulting deformation.

import sirf.Reg as pReg


# Forward motion fields
mf_resampler = [0] * num_ms
im_res = [0] * num_ms
im_corr = [0] * num_ms
for ind in range(num_ms):
    # NiftyAladinSym: affine registration; NiftyF3dSym is the non-rigid alternative
    #algo = pReg.NiftyF3dSym()
    algo = pReg.NiftyAladinSym()

    # Set up images (magnitude images wrapped as NIfTI 3D images).
    # NOTE(review): the author questioned whether the NiftiImageData3D
    # conversion is necessary -- confirm.
    algo.set_reference_image(pReg.NiftiImageData3D(im_ms_rec[ind]))
    algo.set_floating_image(pReg.NiftiImageData3D(im_ms_rec[0]))

    algo.process()

    mf_forward = algo.get_deformation_field_forward()


    # Create resampler with the same reference/floating roles as the
    # registration: reference = state ind, floating = state 0
    mf_resampler[ind] = pReg.NiftyResample()
    mf_resampler[ind].set_reference_image(rec_ms_fista[ind])
    mf_resampler[ind].set_floating_image(rec_ms_fista[0])
    mf_resampler[ind].add_transformation(mf_forward)
    mf_resampler[ind].set_padding_value(0)
    mf_resampler[ind].set_interpolation_type_to_linear()

    # forward: warp state 0 into state ind;
    # backward (adjoint): bring state ind back to state 0
    im_res[ind] = mf_resampler[ind].forward(rec_ms_fista[0])
    im_corr[ind] = mf_resampler[ind].backward(rec_ms_fista[ind])



In [None]:


# Registration check per motion state (one column each):
#   row 0: state-0 image warped into the state,
#   row 1: reconstruction of the state,
#   row 2: absolute difference of the two.
# All are magnitude images normalised to a maximum of 1.
fig, ax = plt.subplots(3, num_ms, squeeze=False)
plt.setp(ax, xticks=[], yticks=[])
for ind in range(num_ms):
    # Take the magnitude before normalising: dividing complex data by its
    # (lexicographically ordered) complex maximum does not scale to [0, 1]
    rec_im_arr = np.abs(im_res[ind].as_array())
    rec_im_arr /= rec_im_arr.max()
    ms_im_arr = np.abs(im_ms_rec[ind].as_array())
    ms_im_arr /= ms_im_arr.max()
    ax[0, ind].imshow(rec_im_arr[:, 64, :], vmin=0, vmax=1)
    ax[1, ind].imshow(ms_im_arr[:, 64, :], vmin=0, vmax=1)
    # Absolute difference: a signed difference would be clipped at vmin=0
    ax[2, ind].imshow(np.abs(rec_im_arr[:, 64, :] - ms_im_arr[:, 64, :]), vmin=0, vmax=1)

In [None]:
# RTA
im_orig = rec_ms_fista[0]
im_rta = im_corr[0]
for ind in range(1,num_ms):
    im_orig += rec_ms_fista[ind]
    im_rta += im_corr[ind]
    
fig, ax = plt.subplots(2,3)
ax[0,0].imshow(np.abs(im_orig.as_array()[102, :, :]))
ax[0,1].imshow(np.abs(im_orig.as_array()[:, 64, :]))
ax[0,2].imshow(np.abs(im_orig.as_array()[:, :, 64]))

ax[1,0].imshow(np.abs(im_rta.as_array()[102, :, :]))
ax[1,1].imshow(np.abs(im_rta.as_array()[:, 64, :]))
ax[1,2].imshow(np.abs(im_rta.as_array()[:, :, 64]))

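TODO – fix the following error. It is raised inside CIL's `PowerMethod`, presumably while estimating the operator norm of the complex-valued MR model in the cell below: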

~/devel/install/python/cil/optimisation/operators/Operator.py in PowerMethod(operator, iterations, x_init)
    145             x1norm = x1.norm()
    146             if hasattr(x0, 'squared_norm'):
--> 147                 s[it] =numpy.abs( x1.dot(x0) / x0.squared_norm())
    148             else:
    149                 x0norm = x0.norm()

TypeError: can't convert complex to float


In [None]:
# Set up motion-corrected reconstruction: for each motion state the image is
# warped by that state's resampler and then passed through that state's
# acquisition model; the states are stacked into one block operator.
C = [CompositionOperator(am, res) for am, res in zip(E_ms, mf_resampler)]
A = BlockOperator(*C)

# Initial guess: adjoint of the stacked operator applied to the
# motion-resolved data
acq_ms_block = BlockDataContainer(*acq_ms)
im_xinit = A.adjoint(acq_ms_block)

num_it_fista = 1
f = LeastSquares(A, acq_ms_block, c=1)

# Regulariser: None, 'tv' or 'tgv'.
# NOTE(review): the CCPi regularisers work on real-valued arrays; whether they
# accept the complex-valued MR images is unverified.
reg_mcir_fista = None
if reg_mcir_fista is None:
    G = ZeroFunction()

elif reg_mcir_fista == 'tv':
    # Isotropic FGP-TV, 10 inner iterations, no non-negativity constraint
    G = FGP_TV(alpha=1e-8, max_iteration=10, tolerance=1e-7,
               isotropic=True, nonnegativity=False, device='cpu')

elif reg_mcir_fista == 'tgv':
    from cil.plugins.ccpi_regularisation.functions import TGV
    # CIL's TGV fixes alpha1 = 1 and uses gamma as the ratio alpha2 / alpha1
    # (here 2 / 1). NOTE(review): confirm this mapping of the parameters.
    G = TGV(alpha=.01, gamma=2., max_iteration=10, tolerance=1e-4,
            device='cpu')

else:
    raise ValueError('reg_mcir_fista should be None, tv or tgv')

# Run FISTA for least squares
fista = FISTA(x_init=im_xinit, f=f, g=G)
fista.max_iteration = num_it_fista
fista.update_objective_interval = 1
# Request exactly num_it_fista iterations so the call agrees with max_iteration
fista.run(num_it_fista, verbose=True)



In [None]:
# Magnitude of the motion-corrected FISTA reconstruction in three orthogonal slices
rec_im_arr = fista.get_output().as_array()

fig, ax = plt.subplots(1, 3)
mcir_views = [rec_im_arr[102, :, :], rec_im_arr[:, 64, :], rec_im_arr[:, :, 64]]
for panel, view in zip(ax, mcir_views):
    panel.imshow(np.abs(view))