# Joint TV for multi-contrast MR
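This demonstration shows how to reconstruct two MR contrasts (T1 and T2) jointly, using a smoothed joint total variation (JTV) prior and gradient descent. For comparison, T1 is also reconstructed on its own with FISTA (least squares, with and without TV).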

It builds on the notebook *acquisition_model_mr_pet_ct.ipynb*. The first part of this notebook, which creates acquisition models and simulates data from BrainWeb, uses the same code but with fewer comments. If anything is unclear, please refer to that notebook for further information.

This demo is a jupyter notebook, i.e. intended to be run step by step.
You could export it as a Python file and run it one go, but that might
make little sense as the figures are not labelled.


Author: Christoph Kolbitsch, Edoardo Pasca
First version: 23rd of April 2021  

CCP PETMR Synergistic Image Reconstruction Framework (SIRF).  
Copyright 2015 - 2017 Rutherford Appleton Laboratory STFC.  
Copyright 2015 - 2019 University College London.   
Copyright 2021 Physikalisch-Technische Bundesanstalt.

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Initial set-up

In [None]:
# Make sure figures appear inline and animations work
# (IPython magic: only valid inside Jupyter/IPython, not in a plain Python script)
%matplotlib notebook

In [None]:
# Make sure everything is installed that we need
# (IPython shell escape: installs the brainweb and nibabel packages with --user)
!pip install brainweb nibabel --user

In [None]:
# Initial imports etc
import numpy
from numpy.linalg import norm
import matplotlib.pyplot as plt

import os
import sys
import shutil
import brainweb
from tqdm.auto import tqdm

import time

# Import MR functionality
import sirf.Gadgetron as mr

from sirf.Utilities import examples_data_path
from cil.framework import  AcquisitionGeometry, BlockDataContainer, BlockGeometry
from cil.optimisation.functions import Function, OperatorCompositionFunction, SmoothMixedL21Norm, L1Norm, L2NormSquared, BlockFunction, MixedL21Norm, IndicatorBox, TotalVariation, LeastSquares, ZeroFunction
from cil.optimisation.operators import GradientOperator, BlockOperator, ZeroOperator, CompositionOperator,LinearOperator
from cil.optimisation.algorithms import PDHG, FISTA, GD
from cil.plugins.ccpi_regularisation.functions import FGP_TV

# Utilities

In [None]:
# First define some handy function definitions
# To make subsequent code cleaner, we have a few functions here. You can
# ignore them when you first see this demo.

def plot_2d_image(idx,vol,title,clims=None,cmap="viridis"):
    """Plot a 2D image into a subplot of the current figure.

    :param idx: subplot specification, unpacked into ``plt.subplot`` (e.g. ``[1, 2, 1]``)
    :param vol: 2D array to display
    :param title: title of the subplot
    :param clims: optional colour limits passed to ``plt.clim``; matplotlib's
        automatic scaling is kept when None
    :param cmap: matplotlib colour map name
    """
    plt.subplot(*idx)
    plt.imshow(vol, cmap=cmap)
    if clims is not None:
        plt.clim(clims)
    plt.colorbar()
    plt.title(title)
    # Hide axis ticks/labels: only the image content is of interest
    plt.axis("off")

def crop_and_fill(templ_im, vol):
    """Centre-crop a volume to the template's size and return a filled copy of the template.

    The template itself is not modified: it is copied, and the copy is filled
    with the central region of ``vol`` that matches the template's array shape.

    :param templ_im: template image providing ``as_array()``, ``copy()`` and
        ``fill()``; a template array with fewer than 3 dimensions is treated as
        having leading singleton dimensions
    :param vol: 3D array to crop; along every axis it must be at least as large
        as the (padded) template shape
    :return: copy of ``templ_im`` filled with the cropped data
    :raises ValueError: if ``vol`` is not 3D or is smaller than the template
    """
    # Template shape, left-padded with singleton axes so it is always 3D
    idim_orig = templ_im.as_array().shape
    idim = (1,) * (3 - len(idim_orig)) + idim_orig

    vol = numpy.asarray(vol)
    if vol.ndim != 3:
        raise ValueError('crop_and_fill expects a 3D volume, got shape {}'.format(vol.shape))
    offset = (numpy.array(vol.shape) - numpy.array(idim)) // 2
    if numpy.any(offset < 0):
        # A negative start index would silently wrap the slices below
        raise ValueError('volume of shape {} is smaller than template shape {}'.format(vol.shape, idim))

    # Centre crop: idim[k] samples starting at offset[k] along each axis
    vol = vol[tuple(slice(start, start + size) for start, size in zip(offset, idim))]

    # Copy the template so the caller's image is not overwritten
    templ_im_out = templ_im.copy()
    templ_im_out.fill(numpy.reshape(vol, idim_orig))
    return templ_im_out

# Get brainweb data

We will download and use data from BrainWeb. MR usually provides qualitative images with an image contrast proportional to differences in T1, T2 or T2* depending on the sequence parameters. Nevertheless, we will make our life easy by directly using the T1 and T2 maps provided by BrainWeb as our two MR contrasts.

In [None]:
# Pick the first BrainWeb subject (first entry of brainweb.utils.LINKS in sorted order)
# and fetch its file into the current directory.
fname, url= sorted(brainweb.utils.LINKS.items())[0]
files = brainweb.get_file(fname, url, ".")
# NOTE(review): `files` and `data` are not used below; the next cell passes fname
# to get_mmr_fromfile instead.
data = brainweb.load_file(fname)

# Presumably seeds brainweb's own random generator so the simulated noise is reproducible
brainweb.seed(1337)

In [None]:
# Build the simulated mMR ground-truth maps for the selected subject. The
# *Noise/*Sigma arguments presumably set noise level and smoothing per map.
# The list holds a single subject, so `vol` holds that subject's maps afterwards.
for f in tqdm([fname], desc="mMR ground truths", unit="subject"):
    vol = brainweb.get_mmr_fromfile(f, petNoise=1, t1Noise=0.75, t2Noise=0.75, petSigma=1, t1Sigma=1, t2Sigma=1)

In [None]:
# The two MR contrasts used throughout this notebook
T2_arr  = vol['T2']
T1_arr   = vol['T1']

# Normalise image data to a maximum value of 1
# NOTE: assuming vol's entries are numpy arrays, the in-place division also
# modifies the arrays stored in `vol`.
T2_arr /= numpy.max(T2_arr)
T1_arr /= numpy.max(T1_arr)

In [None]:
# Display it: the middle slice (along the first array axis) of each normalised
# map, with 100 voxels cropped from every in-plane border
plt.figure();
slice_show = T2_arr.shape[0]//2
plot_2d_image([1,2,1], T2_arr[slice_show, 100:-100, 100:-100], 'T2', cmap="Greys_r")
plot_2d_image([1,2,2], T1_arr[slice_show, 100:-100, 100:-100], 'T1', cmap="Greys_r")

# Acquisition Models

Here we will set up the acquisition models for __MR__.

## MR

In [None]:
# 1. create MR AcquisitionData
# Load the simulated 2D Cartesian GRAPPA-2 example dataset from SIRF's example data.
# Other example datasets (e.g. 'grappa2_1rep.h5') can be used by changing the file name.
mr_acq = mr.AcquisitionData(os.path.join(examples_data_path('MR'), 'simulated_MR_2D_cartesian_Grappa2.h5'))

In [None]:
# 2. calculate CSM
# Preprocess the raw data with SIRF and sort the acquisitions; all acquisition
# models below are built from this preprocessed data.
preprocessed_data = mr.preprocess_acquisition_data(mr_acq)
preprocessed_data.sort()

# Estimate coil sensitivity maps from the preprocessed data
csm = mr.CoilSensitivityData()
# presumably controls the smoothing of the estimated maps - confirm in SIRF docs
csm.smoothness = 50
csm.calculate(preprocessed_data)

In [None]:
# 3. calculate image template
# Reconstruct the preprocessed data with a simple fully-sampled reconstructor.
# The result im_mr is used below as the image template: it defines the image
# space of the acquisition models and is filled with BrainWeb data by crop_and_fill.
recon = mr.FullySampledReconstructor()
recon.set_input(preprocessed_data)
recon.process()
im_mr = recon.get_output()

In [None]:
# Display it: magnitude of two coil sensitivity maps and of the template image
# NOTE(review): assumes the first axis of csm.as_array() indexes the coils - confirm
plt.figure();
csm_arr = numpy.abs(csm.as_array())
im_mr_arr = numpy.abs(im_mr.as_array())

plot_2d_image([1,3,1], csm_arr[0, 0, :, :], 'Coil 0', cmap="Greys_r")
plot_2d_image([1,3,2], csm_arr[2, 0, :, :], 'Coil 2', cmap="Greys_r")
plot_2d_image([1,3,3], im_mr_arr[0, :, :], 'Im', cmap="Greys_r")

In [None]:
# ISMRMRD kspace_encode_step_1 (phase-encoding) index of every acquisition,
# indexed by acquisition number; used below to choose which readouts to keep
encode_step_1 = preprocessed_data.parameter_info('kspace_encode_step_1')

In [None]:
import random

# Dedicated, seeded generator so that the random undersampling masks (and hence
# all reconstructions below) are reproducible between runs
ky_sampling_seed = 1337
_ky_rng = random.Random(ky_sampling_seed)

# Acquisition index of the k-space centre
ky_0_idx = len(encode_step_1)//2
# Candidate acquisitions for random sampling: everything outside the hard-coded
# central range [57, 86). NOTE(review): these bounds are specific to the example
# dataset and should enclose the fully sampled centre defined below.
ky_us_idx = numpy.concatenate((numpy.arange(0,57), numpy.arange(86,len(encode_step_1))), axis=0)
# Number of fully sampled central lines and of randomly chosen outer lines
ky_num_fs = 20
ky_num_us = 60


def _sample_ky_indices(rng):
    """Return acquisition indices: ky_num_fs central lines followed by ky_num_us random outer lines."""
    centre = numpy.arange(ky_0_idx-ky_num_fs//2, ky_0_idx+ky_num_fs//2)
    outer = numpy.asarray(rng.sample(list(ky_us_idx), ky_num_us))
    return numpy.concatenate((centre, outer), axis=0)


# T1 and T2 get independent random patterns (drawn in this order)
acq_idx_t1 = _sample_ky_indices(_ky_rng)
acq_idx_t2 = _sample_ky_indices(_ky_rng)

In [None]:
# Sampling patterns plotted against kspace_encode_step_1: all acquisitions
# (blue) and the subsets kept for T1 (red) and T2 (green)
plt.figure()
plt.plot(encode_step_1, numpy.ones(encode_step_1.shape), 'b.')
plt.plot(encode_step_1[acq_idx_t1], numpy.ones(encode_step_1[acq_idx_t1].shape), 'r.')
plt.plot(encode_step_1[acq_idx_t2], numpy.ones(encode_step_1[acq_idx_t2].shape), 'g.')

In [None]:
# Random sampling
# Set to False to use the full preprocessed data set for both contrasts instead
use_random_sampling = True


def _select_acquisitions(acq_indices):
    """Return a sorted container with only the given acquisitions of preprocessed_data."""
    acq_subset = preprocessed_data.new_acquisition_data(empty=True)
    for acq_index in acq_indices:
        acq_subset.append_acquisition(preprocessed_data.acquisition(acq_index))
    acq_subset.sort()
    return acq_subset


def _mr_acquisition_model(acq_data):
    """Create an MR acquisition model for acq_data on the im_mr template with the csm coil maps."""
    acq_model = mr.AcquisitionModel(acq_data, im_mr)
    acq_model.set_coil_sensitivity_maps(csm)
    return acq_model


if use_random_sampling:
    # Each contrast only "acquires" the k-space lines selected by its own mask
    acq_dat_t1 = _select_acquisitions(acq_idx_t1)
    acq_mod_mr_t1 = _mr_acquisition_model(acq_dat_t1)

    acq_dat_t2 = _select_acquisitions(acq_idx_t2)
    acq_mod_mr_t2 = _mr_acquisition_model(acq_dat_t2)
else:
    acq_mod_mr_t1 = _mr_acquisition_model(preprocessed_data)
    acq_mod_mr_t2 = _mr_acquisition_model(preprocessed_data)


# Simulate raw data

Here we will use the acquisition models to create simulated raw data and then do a simple reconstruction to obtain initial images (i.e. starting points) for our iterative algorithms. For each contrast (T1 and T2) we will:

 * Fill an image template (`im_mr_t1`, `im_mr_t2`)
 * Create raw data (`raw_mr_t1`, `raw_mr_t2`)
 * Reconstruct an initial guess of our image using `backward`/`adjoint` (`bwd_mr_t1`, `bwd_mr_t2`)

In [None]:
# MR
# For each contrast: fill the im_mr template with the centre-cropped BrainWeb
# map, simulate k-space data with that contrast's acquisition model, and apply
# the backward (adjoint) operator to obtain an initial image estimate.
im_mr_t1 = crop_and_fill(im_mr, T1_arr)
raw_mr_t1 = acq_mod_mr_t1.forward(im_mr_t1)
bwd_mr_t1 = acq_mod_mr_t1.backward(raw_mr_t1)

im_mr_t2 = crop_and_fill(im_mr, T2_arr)
raw_mr_t2 = acq_mod_mr_t2.forward(im_mr_t2)
bwd_mr_t2 = acq_mod_mr_t2.backward(raw_mr_t2)


In [None]:
# Display it: magnitude of the backward (adjoint) images, middle slice along
# the first array axis
plt.figure();
slice_show = bwd_mr_t1.as_array().shape[0]//2
plot_2d_image([1,2,1], numpy.abs(bwd_mr_t2.as_array())[slice_show, :, :], 'T2', cmap="Greys_r")
plot_2d_image([1,2,2], numpy.abs(bwd_mr_t1.as_array())[slice_show, :, :], 'T1', cmap="Greys_r")

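# Joint TV

We now reconstruct T1 and T2 jointly. A smoothed joint total variation (JTV) term couples the two images, and we alternate gradient descent on the T1 and the T2 subproblems.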

In [None]:
class ProjectionMap(LinearOperator):
    """Linear operator extracting one component of a BlockDataContainer.

    ``direct`` maps a block ``(x_0, ..., x_{n-1})`` to ``x_index``. ``adjoint``
    maps an element ``y`` of the range to the block that equals ``y`` at
    position ``index`` and is zero everywhere else.

    :param domain_geometry: BlockGeometry of the input blocks
    :param index: position of the component to extract
    :param range_geometry: geometry of the output; defaults to the
        ``index``-th geometry of ``domain_geometry``
    """

    def __init__(self, domain_geometry, index, range_geometry=None):

        self.index = index
        if range_geometry is None:
            range_geometry = domain_geometry.geometries[self.index]

        super(ProjectionMap, self).__init__(domain_geometry=domain_geometry,
                                            range_geometry=range_geometry)

    def direct(self, x, out=None):
        """Return (or write into ``out``) the ``index``-th component of ``x``."""
        if out is None:
            return x.get_item(self.index)
        out.fill(x.get_item(self.index))

    def adjoint(self, x, out=None):
        """Embed ``x`` at position ``index`` of a block that is zero elsewhere.

        When ``out`` is given, ALL of its components are overwritten, not just
        component ``index``. ``out`` may be a reused buffer (e.g. an algorithm's
        gradient storage) that still holds data from a previous call, so the
        other components must be reset to zero explicitly.
        """
        result = self.domain_geometry().allocate(0)
        result[self.index].fill(x)
        if out is None:
            return result
        out.fill(result)
  

In [None]:
class SmoothJointTV(Function):
    r"""Smoothed joint total variation of two images, differentiated w.r.t. one of them.

    For a block ``x = (u, v)`` this evaluates::

        JTV(u, v) = | sum sqrt(lambda_par*|grad u|^2 + (1-lambda_par)*|grad v|^2 + epsilon^2) |

    where ``grad`` stacks finite differences along array directions 1 and 2 and
    the sum runs over all voxels. ``gradient`` returns the derivative with
    respect to the component selected by ``axis`` only; the other component of
    the returned block is zero.
    """

    def __init__(self, epsilon, axis, lambda_par, image_template=None):
        r'''
        :param epsilon: smoothing parameter making MixedL21Norm differentiable; must be non-zero
        :param axis: component of ``x`` to differentiate (0 for ``u``, anything else for ``v``)
        :param lambda_par: weight of ``|grad u|^2``; ``|grad v|^2`` is weighted by ``1 - lambda_par``
        :param image_template: image used as domain geometry of the finite-difference
            operators; defaults to the notebook's global ``bwd_mr_t1``
        :raises ValueError: if ``epsilon`` is zero
        '''
        # Local import so the class does not rely on a later cell importing it
        from cil.optimisation.operators import FiniteDifferenceOperator

        if epsilon == 0:
            raise ValueError('Working with smooth JTV atm')

        #TODO L=?? (Lipschitz constant of the gradient not derived yet)
        super(SmoothJointTV, self).__init__(L=numpy.sqrt(8))

        # smoothing parameter
        self.epsilon = epsilon
        # Which variable to differentiate
        self.axis = axis
        self.lambda_par = lambda_par

        if image_template is None:
            image_template = bwd_mr_t1
        # Discrete gradient: finite differences along array directions 1 and 2
        FDy = FiniteDifferenceOperator(image_template, direction=1)
        FDx = FiniteDifferenceOperator(image_template, direction=2)
        self.grad = BlockOperator(FDy, FDx)

    def _joint_magnitude(self, grad_u, grad_v):
        """Voxelwise sqrt(lambda_par*|grad u|^2 + (1-lambda_par)*|grad v|^2 + epsilon^2)."""
        return (self.lambda_par*grad_u.pnorm(2).power(2)
                + (1-self.lambda_par)*grad_v.pnorm(2).power(2)
                + self.epsilon**2).sqrt()

    def __call__(self, x):
        r"""Evaluate the smoothed joint TV of ``x = (u, v)``.

        :param x: BlockDataContainer holding the two images ``u`` and ``v``
        :return: absolute value of the voxel sum (complex-typed images give a complex sum)
        :raises ValueError: if ``x`` is not a BlockDataContainer
        """
        if not isinstance(x, BlockDataContainer):
            raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x)))

        grad_u = self.grad.direct(x.get_item(0))
        grad_v = self.grad.direct(x.get_item(1))
        return numpy.abs(self._joint_magnitude(grad_u, grad_v).sum())

    def gradient(self, x, out=None):
        r"""Gradient of the smoothed JTV with respect to component ``axis`` of ``x``.

        This is ``grad^T(lambda_par * grad u / D)`` for ``axis == 0`` and
        ``grad^T((1 - lambda_par) * grad v / D)`` otherwise, with ``D`` the
        smoothed joint magnitude. The result lives in the same space as ``x``,
        and its other component is zero. When ``out`` is given it is fully
        overwritten, because it may be a reused buffer holding stale data.
        """
        grad_u = self.grad.direct(x.get_item(0))
        grad_v = self.grad.direct(x.get_item(1))
        denom = self._joint_magnitude(grad_u, grad_v)

        if self.axis == 0:
            num = self.lambda_par*grad_u
        else:
            num = (1-self.lambda_par)*grad_v

        # Block shaped like x, zero except for the differentiated component
        result = x * 0.0
        result[self.axis].fill(self.grad.adjoint(num.divide(denom)))
        if out is None:
            return result
        out.fill(result)
        
            

In [None]:
from cil.framework import ImageGeometry
from cil.optimisation.operators import FiniteDifferenceOperator, BlockOperator
# Quick check of the discrete gradient used by the JTV: finite differences of
# the T1 backward image along array directions 1 and 2, stacked as a block
FDy = FiniteDifferenceOperator(bwd_mr_t1, direction=1)
FDx = FiniteDifferenceOperator(bwd_mr_t1, direction=2)
B = BlockOperator(FDy, FDx)
x = bwd_mr_t1.clone()
res = B.direct(x)

In [None]:
# Magnitude of the T1 backward image (NOTE: im_opt_t1 is reassigned by the FISTA cell below)
im_opt_t1 = numpy.squeeze(numpy.abs(bwd_mr_t1.as_array()))

# Compare the T1 backward image with the first finite-difference component
# res.get_item(0), i.e. the FDy output (direction=1), in the in-plane region 70:130.
# NOTE(review): the title 'FD_x_y' suggests both directions - confirm the intended label.
plt.figure()
plot_2d_image([1,2,1], numpy.squeeze(numpy.abs(bwd_mr_t1.as_array()[0, 70:130, 70:130])), 'T1 pseudo inv', cmap="Greys_r")
plot_2d_image([1,2,2], numpy.squeeze(numpy.abs(res.get_item(0).as_array()[0, 70:130, 70:130])), 'FD_x_y', cmap="Greys_r")


In [None]:
# Number of FISTA iterations, used both as the iteration cap and as the run length
num_it_fista = 10
x_init = bwd_mr_t1.clone()

# Data fidelity c*||A x - b||^2 for the T1 acquisition model
t1 = time.time()
f = LeastSquares(acq_mod_mr_t1, raw_mr_t1, c=1)
print('LS {:3.2f}s'.format((time.time() - t1)))

# No regulariser: plain least squares
G = ZeroFunction()

# Run FISTA for least squares
t1 = time.time()
fista = FISTA(x_init=x_init, f=f, g=G)
fista.max_iteration = num_it_fista
fista.update_objective_interval = 2
print('SETUP {:3.2f}s'.format((time.time() - t1)))

t1 = time.time()
# Request exactly num_it_fista iterations, consistent with max_iteration above
fista.run(num_it_fista, verbose=True)
print('FISTA {:3.2f}s'.format((time.time() - t1)))

im_opt_t1 = numpy.squeeze(numpy.abs(fista.get_output().as_array()))

In [None]:
# Run FISTA with TV
# Weight of the TV regulariser
alpha = 0.2
# NOTE(review): FGP_TV (CCPi regularisation toolkit) is typically used on
# real-valued images, while the MR images here are presumably complex - confirm
# it handles them as intended.
G = alpha * FGP_TV(max_iteration=10, device='cpu')

t1 = time.time()
fista = FISTA(x_init=x_init, f=f, g=G)
fista.max_iteration = num_it_fista
fista.update_objective_interval = 2
print('SETUP {:3.2f}s'.format((time.time() - t1)))

t1 = time.time()
# Request exactly num_it_fista iterations, consistent with max_iteration above
fista.run(num_it_fista, verbose=True)
print('FISTA {:3.2f}s'.format((time.time() - t1)))

im_opt_t1_tv = numpy.squeeze(numpy.abs(fista.get_output().as_array()))

In [None]:
# Compare, in the in-plane region 70:130: the T1 backward image, the FISTA
# least-squares and FISTA+TV reconstructions, and the ground truth
plt.figure()
plot_2d_image([2,3,1], numpy.squeeze(numpy.abs(bwd_mr_t1.as_array()[0, 70:130, 70:130])), 'T1 pseudo inv', cmap="Greys_r")
plot_2d_image([2,3,2], im_opt_t1[70:130, 70:130], 'T1 FISTA', cmap="Greys_r")
plot_2d_image([2,3,5], im_opt_t1_tv[70:130, 70:130], 'T1 FISTA TV', cmap="Greys_r")
plot_2d_image([2,3,3], numpy.squeeze(numpy.abs(im_mr_t1.as_array()[0, 70:130, 70:130])), 'T1 GT', cmap="Greys_r")
plt.show() 

In [None]:
# Weights of the JTV terms in the T1 (alpha1) and T2 (alpha2) subproblems
alpha1 = 0.001
alpha2 = 0.001
# Weight of |grad u|^2 in the joint magnitude of SmoothJointTV
lambda_par = 0.5
# JTV smoothing parameter (must be non-zero)
epsilon = 1e-4

# Joint image space: blocks (T1 image, T2 image)
bg = BlockGeometry(bwd_mr_t1, bwd_mr_t2)

# Projections picking the T1 (index 0) or the T2 (index 1) component of a block
L1 = ProjectionMap(bg,index=0)
L2 = ProjectionMap(bg,index=1)

# k-space data fidelities 0.5*||y - b||^2 for each contrast
f1 = 0.5*L2NormSquared(b=raw_mr_t1)
f2 = 0.5*L2NormSquared(b=raw_mr_t2)

# NOTE(review): JTV2 is built with 1-lambda_par, so for lambda_par != 0.5 the two
# subproblems weight |grad T1|^2 and |grad T2|^2 differently, i.e. they do not
# share one joint functional - confirm this is intended.
JTV1 = alpha1*SmoothJointTV(epsilon=epsilon, axis=0, lambda_par = lambda_par )
JTV2 = alpha2*SmoothJointTV(epsilon=epsilon, axis=1, lambda_par = 1-lambda_par)
# Each objective takes the full block: data term on one contrast + its JTV term
objective1 = OperatorCompositionFunction(f1, CompositionOperator(acq_mod_mr_t1,L1)) + JTV1
objective2 = OperatorCompositionFunction(f2, CompositionOperator(acq_mod_mr_t2,L2)) + JTV2

In [None]:
# Start from zero images for both contrasts
x0 = bg.allocate(0.0)

# Alternating minimisation: 5 outer rounds. Each round runs 10 GD iterations on
# objective1 (T1 data term + JTV1), then 10 on objective2 (T2 data term + JTV2)
# starting from gd1's result, and stores the outcome back into x0.
for i in range(5):
    
    # NOTE(review): alpha is passed through to CIL's GD, presumably as the initial
    # step size of its line search - confirm against the installed CIL version
    gd1 = GD(x0, objective1, alpha=1e9, \
          max_iteration = 10, update_objective_interval = 1)
    gd1.run(verbose=1)
    

    gd2 = GD(gd1.solution, objective2, alpha=1e9,\
          max_iteration = 10, update_objective_interval = 1)
    gd2.run(verbose=1) 
    
    # Carry the current (T1, T2) estimate over to the next round
    x0.fill(gd2.solution)
    
    print(i)

In [None]:
# Magnitude of the jointly reconstructed T1 and T2 images
im_opt_t1 = numpy.squeeze(numpy.abs(x0.get_item(0).as_array()))
im_opt_t2 = numpy.squeeze(numpy.abs(x0.get_item(1).as_array()))

# Top row: T1 backward image, JTV reconstruction, ground truth;
# bottom row: the same for T2 (in-plane region 70:130)
plt.figure()
plot_2d_image([2,3,1], numpy.squeeze(numpy.abs(bwd_mr_t1.as_array()[0, 70:130, 70:130])), 'T1 pseudo inv', cmap="Greys_r")
plot_2d_image([2,3,2], im_opt_t1[70:130, 70:130], 'T1 JTV', cmap="Greys_r")
plot_2d_image([2,3,3], numpy.squeeze(numpy.abs(im_mr_t1.as_array()[0, 70:130, 70:130])), 'T1 GT', cmap="Greys_r") 

plot_2d_image([2,3,4], numpy.squeeze(numpy.abs(bwd_mr_t2.as_array()[0, 70:130, 70:130])), 'T2 pseudo inv', cmap="Greys_r")
plot_2d_image([2,3,5], im_opt_t2[70:130, 70:130], 'T2 JTV', cmap="Greys_r")
plot_2d_image([2,3,6], numpy.squeeze(numpy.abs(im_mr_t2.as_array()[0, 70:130, 70:130])), 'T2 GT', cmap="Greys_r") 