# Joint TV for multi-contrast MR
This demonstration shows how to do a synergistic reconstruction of two MR images with different contrasts. Both MR images show the same underlying anatomy, just with different contrast. To make use of this similarity, a joint total variation (TV) operator is used as a regulariser in an iterative image reconstruction approach.

This demo is a Jupyter notebook, i.e. intended to be run step by step.
You could export it as a Python file and run it in one go, but that might
make little sense as the figures are not labelled.


Authors: Christoph Kolbitsch, Evangelos Papoutsellis, Edoardo Pasca
First version: 16th of June 2021  

CCP PETMR Synergistic Image Reconstruction Framework (SIRF).  
Copyright 2021 Rutherford Appleton Laboratory STFC.    
Copyright 2021 Physikalisch-Technische Bundesanstalt.

This is software developed for the Collaborative Computational
Project in Positron Emission Tomography and Magnetic Resonance imaging
(http://www.ccppetmr.ac.uk/).

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# Initial set-up

In [None]:
# Make sure figures appear inline and animations work
%matplotlib notebook

In [None]:
# Make sure everything is installed that we need
!pip install brainweb nibabel --user

In [None]:
# Initial imports etc
import numpy
from numpy.linalg import norm
import matplotlib.pyplot as plt

import os
import sys
import shutil
import brainweb
from tqdm.auto import tqdm

import time

# Import MR functionality
import sirf.Gadgetron as mr

# Import CIL functionality
from sirf.Utilities import examples_data_path
from cil.framework import  AcquisitionGeometry, BlockDataContainer, BlockGeometry, ImageGeometry
from cil.optimisation.functions import Function, OperatorCompositionFunction, SmoothMixedL21Norm, L1Norm, L2NormSquared, BlockFunction, MixedL21Norm, IndicatorBox, TotalVariation, LeastSquares, ZeroFunction
from cil.optimisation.operators import GradientOperator, BlockOperator, ZeroOperator, CompositionOperator, LinearOperator, FiniteDifferenceOperator
from cil.optimisation.algorithms import PDHG, FISTA, GD
from cil.plugins.ccpi_regularisation.functions import FGP_TV

# Utilities

In [None]:
# First define some handy function definitions
# To make subsequent code cleaner, we have a few functions here. You can
# ignore them when you first see this demo.

def plot_2d_image(idx,vol,title,clims=None,cmap="viridis"):
    """Customized version of subplot to plot 2D image"""
    plt.subplot(*idx)
    plt.imshow(vol,cmap=cmap)
    if clims is not None:
        plt.clim(clims)
    plt.colorbar()
    plt.title(title)
    plt.axis("off")

def crop_and_fill(templ_im, vol):
    """Crop volumetric image data and replace image content in template image object"""
    # Get size of template image and crop
    idim_orig = templ_im.as_array().shape
    idim = (1,)*(3-len(idim_orig)) + idim_orig
    offset = (numpy.array(vol.shape) - numpy.array(idim)) // 2
    vol = vol[offset[0]:offset[0]+idim[0], offset[1]:offset[1]+idim[1], offset[2]:offset[2]+idim[2]]
    
    # Make a copy of the template to ensure we do not overwrite it
    templ_im_out = templ_im.copy()
    
    # Fill image content 
    templ_im_out.fill(numpy.reshape(vol, idim_orig))
    return(templ_im_out)

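The shape bookkeeping in `crop_and_fill` can be checked without SIRF. The following is a minimal NumPy sketch of the same centre-cropping step; `centre_crop` is a hypothetical helper, and it assumes that the volume is at least as large as the template along every axis (otherwise the offsets become negative).

```python
import numpy


def centre_crop(vol, target_shape):
    """Centre-crop a 3D volume to target_shape, using the offset logic of crop_and_fill.

    A target shape with fewer than three dimensions is padded with leading ones,
    e.g. (ny, nx) -> (1, ny, nx), so a single central slice is taken along axis 0.
    """
    target = (1,) * (3 - len(target_shape)) + tuple(target_shape)
    offset = (numpy.array(vol.shape) - numpy.array(target)) // 2
    cropped = vol[offset[0]:offset[0] + target[0],
                  offset[1]:offset[1] + target[1],
                  offset[2]:offset[2] + target[2]]
    # Reshape back to the (possibly lower-dimensional) template shape
    return numpy.reshape(cropped, target_shape)


vol = numpy.arange(5 * 6 * 8).reshape(5, 6, 8)
crop = centre_crop(vol, (2, 4))
print(crop.shape)  # (2, 4)
print(crop[0])     # [114 115 116 117]
```

In `crop_and_fill` the cropped array is then written into a copy of the SIRF template image with `fill`, so the template's geometry information is kept.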
### Joint TV reconstruction of two MR images

Assume we want to reconstruct two MR images $u$ and $v$ and exploit the similarity between them using a joint TV ($\mathrm{JTV}$) regulariser. The reconstruction problem can then be formulated as:

$$
\begin{equation}
(u^{*}, v^{*}) \in \underset{u,v}{\operatorname{argmin}} \frac{1}{2} \| A_{1} u - g_{1}\|^{2}_{2} +  \frac{1}{2} \| A_{2} v - g_{2}\|^{2}_{2} + \alpha\,\mathrm{JTV}_{\eta, \lambda}(u, v) 
\end{equation}
$$

* $\mathrm{JTV}_{\eta, \lambda}(u, v) = \sum \sqrt{ \lambda|\nabla u|^{2} + (1-\lambda)|\nabla v|^{2} + \eta^{2}}$, where the sum runs over all pixels, $0<\lambda<1$ weights the two images and $\eta>0$ is a smoothing parameter (`epsilon` in the code below)
* $A_{1}$, $A_{2}$: __MR__ `AcquisitionModel`
* $g_{1}$, $g_{2}$: __MR__ `AcquisitionData`


### Solving this problem

In order to solve the above minimisation problem, we will use an alternating minimisation approach, where one variable is fixed and we solve with respect to the other variable:

$$
\begin{align*}
u^{k+1} & = \underset{u}{\operatorname{argmin}} \frac{1}{2} \| A_{1} u - g_{1}\|^{2}_{2} + \alpha_{1}\,\mathrm{JTV}_{\eta, \lambda}(u, v^{k}) \quad \text{subproblem 1}\\
v^{k+1} & = \underset{v}{\operatorname{argmin}} \frac{1}{2} \| A_{2} v - g_{2}\|^{2}_{2} + \alpha_{2}\,\mathrm{JTV}_{\eta, 1-\lambda}(u^{k+1}, v) \quad \text{subproblem 2}
\end{align*}
$$

We are going to use a gradient descent approach to solve each of these subproblems alternatingly.

In general, the *regularisation parameters* $\alpha_{1}$ and $\alpha_{2}$ can differ between the two subproblems, but we will not worry about tuning them at this stage. Note also that the first JTV weights the gradients of $u$ and $v$ with $\lambda$ and $1-\lambda$, whereas the second JTV weights them with $1-\lambda$ and $\lambda$, with $0<\lambda<1$.
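To make the JTV term and the gradient needed for the gradient descent steps concrete, here is a small NumPy sketch. It assumes real-valued 2D arrays and forward differences that are zero at the last row and column (a common discretisation, not necessarily identical to CIL's `FiniteDifferenceOperator`); `grad`, `grad_adjoint`, `jtv` and `jtv_gradient_u` are hypothetical helper names. The gradient with respect to $u$ is $\nabla^{T}\big(\lambda \nabla u / \sqrt{\lambda|\nabla u|^{2} + (1-\lambda)|\nabla v|^{2} + \eta^{2}}\big)$.

```python
import numpy


def grad(u):
    """Forward differences along y (axis 0) and x (axis 1), zero at the last row/column."""
    gy = numpy.zeros_like(u)
    gx = numpy.zeros_like(u)
    gy[:-1, :] = u[1:, :] - u[:-1, :]
    gx[:, :-1] = u[:, 1:] - u[:, :-1]
    return gy, gx


def grad_adjoint(py, px):
    """Adjoint of grad (minus the discrete divergence)."""
    out = numpy.zeros_like(py)
    out[:-1, :] -= py[:-1, :]
    out[1:, :] += py[:-1, :]
    out[:, :-1] -= px[:, :-1]
    out[:, 1:] += px[:, :-1]
    return out


def jtv_denominator(u, v, lam, eta):
    """Pixel-wise sqrt(lam*|grad u|^2 + (1-lam)*|grad v|^2 + eta^2)."""
    uy, ux = grad(u)
    vy, vx = grad(v)
    return numpy.sqrt(lam * (uy**2 + ux**2) + (1 - lam) * (vy**2 + vx**2) + eta**2)


def jtv(u, v, lam, eta):
    """JTV_{eta,lam}(u, v): sum of the pixel-wise denominator."""
    return numpy.sum(jtv_denominator(u, v, lam, eta))


def jtv_gradient_u(u, v, lam, eta):
    """Gradient of JTV_{eta,lam}(u, v) with respect to u, with v fixed."""
    uy, ux = grad(u)
    denom = jtv_denominator(u, v, lam, eta)
    return grad_adjoint(lam * uy / denom, lam * ux / denom)


rng = numpy.random.default_rng(0)
u = rng.standard_normal((16, 16))
v = rng.standard_normal((16, 16))
# JTV_{eta,lam}(u, v) == JTV_{eta,1-lam}(v, u), so the gradient with respect to v
# is obtained by swapping the arguments and replacing lam by 1 - lam
g_v = jtv_gradient_u(v, u, 1 - 0.5, 0.1)
```

The same symmetry is what makes the second subproblem use $\mathrm{JTV}_{\eta, 1-\lambda}$: the weight on the image being updated stays in the same slot of the formula.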


This notebook builds on several other notebooks, and hence certain steps will be carried out with minimal documentation. If you want more explanations, please refer to the corresponding notebooks mentioned in the following list. The steps we are going to carry out are:

  - (A) Get a T1-weighted and a T2-weighted image from BrainWeb, which we are going to use as ground truth $u_{gt}$ and $v_{gt}$ for our reconstruction (further information: `introduction` notebook)
  
  - (B) Create __MR__ `AcquisitionModel` $A_{1}$ and $A_{2}$ and simulate undersampled __MR__ `AcquisitionData` $g_{1}$ and $g_{2}$ (further information: `acquisition_model_mr_pet_ct` notebook)
  
  - (C) Set up the joint TV reconstruction problem
  
  - (D) Solve the joint TV reconstruction problem (further information on gradient descent: `gradient_descent_mr_pet_ct` notebook)

# (A) Get brainweb data

We will download and use data from BrainWeb.

In [None]:
fname, url= sorted(brainweb.utils.LINKS.items())[0]
files = brainweb.get_file(fname, url, ".")
data = brainweb.load_file(fname)

brainweb.seed(1337)

In [None]:
for f in tqdm([fname], desc="mMR ground truths", unit="subject"):
    vol = brainweb.get_mmr_fromfile(f, petNoise=1, t1Noise=0.75, t2Noise=0.75, petSigma=1, t1Sigma=1, t2Sigma=1)

In [None]:
T2_arr  = vol['T2']
T1_arr   = vol['T1']

# Normalise image data
T2_arr /= numpy.max(T2_arr)
T1_arr /= numpy.max(T1_arr)

In [None]:
# Display it
plt.figure();
slice_show = T2_arr.shape[0]//2
plot_2d_image([1,2,1], T2_arr[slice_show, 100:-100, 100:-100], 'T2', cmap="Greys_r")
plot_2d_image([1,2,2], T1_arr[slice_show, 100:-100, 100:-100], 'T1', cmap="Greys_r")

# (B) Simulate undersampled MR AcquisitionData

In [None]:
# Create MR AcquisitionData
mr_acq = mr.AcquisitionData(examples_data_path('MR') + '/simulated_MR_2D_cartesian_Grappa2.h5')

In [None]:
# Calculate CSM
preprocessed_data = mr.preprocess_acquisition_data(mr_acq)

csm = mr.CoilSensitivityData()
csm.smoothness = 50
csm.calculate(preprocessed_data)

In [None]:
# Calculate image template
recon = mr.FullySampledReconstructor()
recon.set_input(preprocessed_data)
recon.process()
im_mr = recon.get_output()

In [None]:
# Display it
plt.figure();
csm_arr = numpy.abs(csm.as_array())

plot_2d_image([1,2,1], csm_arr[0, 0, :, :], 'Coil 0', cmap="Greys_r")
plot_2d_image([1,2,2], csm_arr[2, 0, :, :], 'Coil 2', cmap="Greys_r")

We are going to use these coil maps to simulate our __MR__ raw data. These coil maps were obtained from a Shepp-Logan phantom, which unfortunately has two ellipses in the centre with very little signal intensity. Therefore, the coil maps cannot be estimated there and mainly contain noise. Keep that in mind; we will come back to this observation later.

Next we are going to create the two __MR__ `AcquisitionModel` $A_{1}$ and $A_{2}$ 

In [None]:
# Create two MR acquisition models
A1 = mr.AcquisitionModel(preprocessed_data, im_mr)
A1.set_coil_sensitivity_maps(csm)

A2 = mr.AcquisitionModel(preprocessed_data, im_mr)
A2.set_coil_sensitivity_maps(csm)

and simulate undersampled __MR__ `AcquisitionData` $g_{1}$ and $g_{2}$ 

In [None]:
# MR
u_gt = crop_and_fill(im_mr, T1_arr)
g1 = A1.forward(u_gt)

v_gt = crop_and_fill(im_mr, T2_arr)
g2 = A2.forward(v_gt)

As a quick check, we are going to apply the backward (adjoint) operation to obtain a simple image reconstruction.

In [None]:
# Simple reconstruction
u_simple = A1.backward(g1)
v_simple = A2.backward(g2)

In [None]:
# Display it
plt.figure();
plot_2d_image([1,2,1], numpy.abs(u_simple.as_array())[0, 70:180, 70:180], '$u_{simple}$', cmap="Greys_r")
plot_2d_image([1,2,2], numpy.abs(v_simple.as_array())[0, 70:180, 70:180], '$v_{simple}$', cmap="Greys_r")

These images look quite poor compared to the ground truth input images, because they are reconstructed from undersampled k-space data. In addition, you can see a strange "structure" going through the centre of the brain. This is related to the coil maps: as mentioned above, our coil maps have two "holes" in the centre, and these create the artefacts. Nevertheless, this is not going to be a problem for our reconstruction, as we will see later on.
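The forward and backward (adjoint) operations used above can be mimicked with a toy Cartesian multi-coil model in NumPy: weight the image by each coil map $S_{c}$, apply an orthonormal 2D FFT and keep only the sampled k-space points. This is a simplified sketch, not SIRF's `AcquisitionModel`; the random coil maps, the sampling mask and the helper names `forward` and `adjoint` are hypothetical. The dot-product test checks that `adjoint` really is the adjoint of `forward`. In this toy model, with full sampling the adjoint applied to the forward model returns $\sum_{c} |S_{c}|^{2} u$, which gives an idea of why poorly estimated coil maps can leave a footprint on a simple adjoint reconstruction.

```python
import numpy


def forward(u, csm, mask):
    """Toy model A u: coil weighting, orthonormal 2D FFT, k-space sampling mask."""
    coil_images = csm * u[None, :, :]
    kspace = numpy.fft.fft2(coil_images, norm="ortho")
    return kspace * mask[None, :, :]


def adjoint(g, csm, mask):
    """Toy model A^H g: mask, inverse orthonormal 2D FFT, combine coils with conj(csm)."""
    coil_images = numpy.fft.ifft2(g * mask[None, :, :], norm="ortho")
    return numpy.sum(numpy.conj(csm) * coil_images, axis=0)


rng = numpy.random.default_rng(0)
n_coils, ny, nx = 4, 16, 16
csm = rng.standard_normal((n_coils, ny, nx)) + 1j * rng.standard_normal((n_coils, ny, nx))
mask = numpy.zeros((ny, nx))
mask[::2, :] = 1.0  # keep every other k-space row (unshifted FFT ordering)

u = rng.standard_normal((ny, nx)) + 1j * rng.standard_normal((ny, nx))
g = rng.standard_normal((n_coils, ny, nx)) + 1j * rng.standard_normal((n_coils, ny, nx))

# Dot-product test: <A u, g> should equal <u, A^H g>
lhs = numpy.vdot(forward(u, csm, mask), g)
rhs = numpy.vdot(u, adjoint(g, csm, mask))
print(abs(lhs - rhs))  # close to zero
```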

# (C) Set up the joint TV reconstruction problem

So far we have mainly used __SIRF__ functionality; now we are going to use __CIL__ to set up the reconstruction problem and then solve it. To reconstruct both $u$ and $v$ at the same time, we will make use of a `BlockDataContainer`. In the following, we define an operator which selects either of these two images from the `BlockDataContainer`, and then a function which calculates the joint TV.

In [None]:
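class ProjectionMap(LinearOperator):
    """Select the component `index` of a BlockDataContainer.

    The adjoint places its input at position `index` and sets all other components to zero.
    """
    
    def __init__(self, domain_geometry, index, range_geometry=None):
        
        self.index = index
        if range_geometry is None:
            range_geometry = domain_geometry.geometries[self.index]
            
        super(ProjectionMap, self).__init__(domain_geometry=domain_geometry, 
                                           range_geometry=range_geometry)
        
    def direct(self, x, out=None):
                        
        if out is None:
            return x[self.index]
        else:
            out.fill(x[self.index])
    
    def adjoint(self, x, out=None):
        
        if out is None:
            tmp = self.domain_geometry().allocate(0.0)
            tmp[self.index].fill(x)
            return tmp
        else:
            # Fill the selected component and zero all other components
            for i, geometry in enumerate(self.domain_geometry().geometries):
                if i == self.index:
                    out[i].fill(x)
                else:
                    out[i].fill(geometry.allocate(0.0))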
  

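The key property of `ProjectionMap` is that its adjoint embeds an image into an otherwise zero `BlockDataContainer`. The dot-product test below illustrates this with plain NumPy arrays, using a Python list as a hypothetical stand-in for the `BlockDataContainer`; `project` and `project_adjoint` are illustrative names, not CIL functions.

```python
import numpy


def project(x, index):
    """P_index: pick one image out of a list of images (like ProjectionMap.direct)."""
    return x[index]


def project_adjoint(y, index, shapes):
    """P_index^T: embed y at position `index` and fill all other slots with zeros."""
    out = [numpy.zeros(shape, dtype=y.dtype) for shape in shapes]
    out[index] = y.copy()
    return out


rng = numpy.random.default_rng(1)
shapes = [(8, 8), (8, 8)]
x = [rng.standard_normal(s) for s in shapes]
y = rng.standard_normal((8, 8))

# Dot-product test <P x, y> == <x, P^T y>, summing over the block components
lhs = numpy.sum(project(x, 1) * y)
rhs = sum(numpy.sum(xi * zi) for xi, zi in zip(x, project_adjoint(y, 1, shapes)))
print(numpy.isclose(lhs, rhs))  # True
```

If the other slots were not zeroed, the second term of the sum would pick up stale values and the test would fail, which is why the adjoint has to reset them.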
In [None]:
class SmoothJointTV(Function):
              
    def __init__(self, epsilon, axis, lambda_par):
                
        r'''
        :param epsilon: smoothing parameter making MixedL21Norm differentiable
        :param axis: index of the image (0 for u, 1 for v) with respect to which the gradient is computed
        :param lambda_par: weight of |grad u|^2; |grad v|^2 is weighted with 1 - lambda_par
        '''

        # TODO: set the correct Lipschitz constant L of the gradient
        super(SmoothJointTV, self).__init__(L=numpy.sqrt(8))
        
        # smoothing parameter
        self.epsilon = epsilon
        
        # GradientOperator built from finite differences along y and x.
        # Note: the global image u_simple is used to define the image geometry.
        FDy = FiniteDifferenceOperator(u_simple, direction=1)
        FDx = FiniteDifferenceOperator(u_simple, direction=2)
        self.grad = BlockOperator(FDy, FDx)
        
        # Which variable to differentiate
        self.axis = axis
        
        if self.epsilon==0:
            raise ValueError('Only the smooth JTV (epsilon > 0) is supported at the moment')
            
        self.lambda_par=lambda_par
                            
    def __call__(self, x):
        
        r"""x is a BlockDataContainer that contains the two images (u, v).
        """
        if not isinstance(x, BlockDataContainer):
            raise ValueError('__call__ expected BlockDataContainer, got {}'.format(type(x))) 

        tmp = numpy.abs((self.lambda_par*self.grad.direct(x[0]).pnorm(2).power(2) + (1-self.lambda_par)*self.grad.direct(x[1]).pnorm(2).power(2)+\
              self.epsilon**2).sqrt().sum())
        return tmp
             
    def gradient(self, x, out=None):
        
        denom = (self.lambda_par*self.grad.direct(x[0]).pnorm(2).power(2) + (1-self.lambda_par)*self.grad.direct(x[1]).pnorm(2).power(2)+\
              self.epsilon**2).sqrt()
        
        if self.axis==0:
            num = self.lambda_par*self.grad.direct(x[0])
        else:
            num = (1-self.lambda_par)*self.grad.direct(x[1])

        if out is None:
            # Same structure as x, i.e. (u, v); only the component `axis` is non-zero
            tmp = x * 0.0
            tmp[self.axis].fill(self.grad.adjoint(num.divide(denom)))
            return tmp
        else:
            self.grad.adjoint(num.divide(denom), out=out[self.axis])
            # The other image gets no contribution from this gradient
            out[1 - self.axis].fill(x[1 - self.axis] * 0.0)
        
            

Now we put everything together and define the two objective functions corresponding to the two subproblems defined at the beginning.

In [None]:
alpha1 = 0.02
alpha2 = 0.02
lambda_par = 0.5
epsilon = 1e-12

bg = BlockGeometry(u_simple, v_simple)

L1 = ProjectionMap(bg, index=0)
L2 = ProjectionMap(bg, index=1)

f1 = 0.5*L2NormSquared(b=g1)
f2 = 0.5*L2NormSquared(b=g2)

JTV1 = alpha1*SmoothJointTV(epsilon=epsilon, axis=0, lambda_par = lambda_par )
JTV2 = alpha2*SmoothJointTV(epsilon=epsilon, axis=1, lambda_par = 1-lambda_par)
objective1 = OperatorCompositionFunction(f1, CompositionOperator(A1, L1)) + JTV1
objective2 = OperatorCompositionFunction(f2, CompositionOperator(A2, L2)) + JTV2

# (D) Solve the joint TV reconstruction problem

In [None]:
# We start with zero-filled images
x0 = bg.allocate(0.0)

# We use a fixed step-size for the gradient descent approach
step_size = 0.1

# We are also going to log the value of the objective functions
obj1_val_it = []
obj2_val_it = []

for i in range(10):
    
    # Subproblem 1: update u with v fixed
    gd1 = GD(x0, objective1, step_size=step_size,
             max_iteration=4, update_objective_interval=1)
    gd1.run(verbose=1)
    
    # We skip the first one because it gets repeated
    obj1_val_it.extend(gd1.objective[1:])
    
    # Here we use a little "trick" to make it easier to see when each subproblem is optimised:
    # we append NaNs to the objective function which is currently not being optimised. The NaNs
    # will not show up in the final plot, and hence we can nicely see each subproblem.
    obj2_val_it.extend(numpy.ones_like(gd1.objective[1:])*numpy.nan)
    
    # Subproblem 2: update v with u fixed
    gd2 = GD(gd1.solution, objective2, step_size=step_size,
             max_iteration=4, update_objective_interval=1)
    gd2.run(verbose=1)
    
    obj2_val_it.extend(gd2.objective[1:])
    obj1_val_it.extend(numpy.ones_like(gd2.objective[1:])*numpy.nan)
    
    x0.fill(gd2.solution)
    
    print('* * * * * * Outer Iteration ', i, ' * * * * * *\n')

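The structure of the loop above (a few gradient descent steps per subproblem, an outer alternation, and NaN padding of the objective that is not being optimised) can be reproduced on a toy problem without SIRF or CIL. This sketch makes several assumptions: the MR acquisition models are replaced by a hypothetical pixel mask $M$ (so the gradient of the data term is $M(Mu - g_{1})$), the images are real-valued, gradient descent is written out by hand instead of using CIL's `GD`, and $\eta = 0.1$ is used instead of `epsilon = 1e-12` to keep the toy gradient well-behaved. The step size, $\alpha_{1}$, $\alpha_{2}$ and $\lambda$ match the values above, and the iteration counts (`n_inner=4`, `n_outer=10`) are chosen to resemble them.

```python
import numpy


def grad(u):
    """Forward differences along y and x, zero at the last row/column."""
    gy = numpy.zeros_like(u)
    gx = numpy.zeros_like(u)
    gy[:-1, :] = u[1:, :] - u[:-1, :]
    gx[:, :-1] = u[:, 1:] - u[:, :-1]
    return gy, gx


def grad_adjoint(py, px):
    """Adjoint of grad."""
    out = numpy.zeros_like(py)
    out[:-1, :] -= py[:-1, :]
    out[1:, :] += py[:-1, :]
    out[:, :-1] -= px[:, :-1]
    out[:, 1:] += px[:, :-1]
    return out


def subproblem(a, b, mask, data, alpha, lam, eta):
    """Value and gradient (w.r.t. a) of 0.5*||M a - data||^2 + alpha*JTV, with b fixed.

    The JTV term weights |grad a|^2 with lam and |grad b|^2 with 1 - lam.
    """
    ay, ax = grad(a)
    by, bx = grad(b)
    denom = numpy.sqrt(lam * (ay**2 + ax**2) + (1 - lam) * (by**2 + bx**2) + eta**2)
    residual = mask * a - data
    value = 0.5 * numpy.sum(residual**2) + alpha * numpy.sum(denom)
    gradient = mask * residual + alpha * grad_adjoint(lam * ay / denom, lam * ax / denom)
    return value, gradient


def alternating_gd(g1, g2, mask1, mask2, alpha1=0.02, alpha2=0.02, lam=0.5,
                   eta=0.1, step_size=0.1, n_outer=10, n_inner=4):
    u = numpy.zeros_like(g1)
    v = numpy.zeros_like(g2)
    obj1, obj2 = [], []
    for _ in range(n_outer):
        # Subproblem 1: update u with v fixed, using JTV_{eta,lam}(u, v)
        for _ in range(n_inner):
            u = u - step_size * subproblem(u, v, mask1, g1, alpha1, lam, eta)[1]
            obj1.append(subproblem(u, v, mask1, g1, alpha1, lam, eta)[0])
            obj2.append(numpy.nan)
        # Subproblem 2: update v with u fixed. JTV_{eta,1-lam}(u, v) equals
        # JTV_{eta,lam}(v, u), so with v in the first slot the weight stays lam
        for _ in range(n_inner):
            v = v - step_size * subproblem(v, u, mask2, g2, alpha2, lam, eta)[1]
            obj2.append(subproblem(v, u, mask2, g2, alpha2, lam, eta)[0])
            obj1.append(numpy.nan)
    return u, v, obj1, obj2


# Two toy images with the same edges but different contrast
yy, xx = numpy.mgrid[0:32, 0:32]
square = ((yy > 8) & (yy < 24) & (xx > 8) & (xx < 24)).astype(float)
u_true, v_true = square, 1.0 - 0.5 * square
rng = numpy.random.default_rng(0)
mask1 = (rng.random(square.shape) < 0.5).astype(float)
mask2 = (rng.random(square.shape) < 0.5).astype(float)
u, v, obj1, obj2 = alternating_gd(mask1 * u_true, mask2 * v_true, mask1, mask2)
```

With these settings the objective of each subproblem should not increase within an inner loop, because the step size is well below the inverse Lipschitz constant of the toy gradient; the NaN entries mark the inner loops of the other subproblem, as in the objective-function plot of this notebook.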
Finally we can look at the images $u_{jtv}$ and $v_{jtv}$ and compare them to the simple reconstruction $u_{simple}$ and $v_{simple}$ and the original ground truth images.

In [None]:
u_jtv = numpy.squeeze(numpy.abs(x0[0].as_array()))
v_jtv = numpy.squeeze(numpy.abs(x0[1].as_array()))

plt.figure()
plot_2d_image([2,3,1], numpy.squeeze(numpy.abs(u_simple.as_array()[0, 70:180, 70:180])), '$u_{simple}$', cmap="Greys_r")
plot_2d_image([2,3,2], u_jtv[70:180, 70:180], '$u_{JTV}$', cmap="Greys_r")
plot_2d_image([2,3,3], numpy.squeeze(numpy.abs(u_gt.as_array()[0, 70:180, 70:180])), '$u_{gt}$', cmap="Greys_r") 

plot_2d_image([2,3,4], numpy.squeeze(numpy.abs(v_simple.as_array()[0, 70:180, 70:180])), '$v_{simple}$', cmap="Greys_r")
plot_2d_image([2,3,5], v_jtv[70:180, 70:180], '$v_{JTV}$', cmap="Greys_r")
plot_2d_image([2,3,6], numpy.squeeze(numpy.abs(v_gt.as_array()[0, 70:180, 70:180])), '$v_{gt}$', cmap="Greys_r") 

And let's look at the objective functions

In [None]:
plt.figure()
plt.plot(obj1_val_it, 'o-', label='subproblem 1')
plt.plot(obj2_val_it, '+-', label='subproblem 2')
plt.xlabel('Number of iterations')
plt.ylabel('Value of objective function')
plt.title('Objective functions')
plt.legend()

# Logarithmic y-axis
plt.yscale('log')

# Next steps

The above is a good demonstration of a synergistic reconstruction of two different images. The following gives a few suggestions of what to do next and how to extend this notebook to other applications.

## Number of iterations
In our problem we have several regularisation parameters such as $\alpha_{1}$, $\alpha_{2}$ and $\lambda$. In addition, the number of inner iterations for each subproblem (set via `max_iteration` in the calls to `GD`) and the number of outer iterations (currently set to 10) also determine the final solution. In the limit of infinitely many iterations the split between inner and outer iterations should not matter, but in practice the number of iterations is limited by the available time.

__TODO__: Change the number of iterations and see what happens to the objective functions. For a given number of total iterations, do you think it is better to have a high number of inner or high number of outer iterations? Why?

## Spatial misalignment
In the above example we simulated our data such that there is a perfect spatial match between $u$ and $v$. For real world applications this usually cannot be assumed. 

__TODO__: Add spatial misalignment between $u$ and $v$. This can be achieved e.g. by calling `numpy.roll` on `T2_arr` before calling `v_gt = crop_and_fill(im_mr, T2_arr)`. What is the effect on the reconstructed images? For a more "advanced" misalignment, have a look at notebook `BrainWeb`.
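As a starting point for this exercise, here is a small sketch of what `numpy.roll` does to an array with the `(slice, y, x)` layout used for `T2_arr` above; `T2_toy` is a hypothetical stand-in. Note that `numpy.roll` wraps values around the edges, which should be harmless as long as the region shifted across the border is background.

```python
import numpy

# Hypothetical stand-in for T2_arr with the (slice, y, x) layout used above
T2_toy = numpy.zeros((3, 8, 8))
T2_toy[:, 2:6, 2:6] = 1.0

# Shift by 2 voxels along x; values pushed past the edge wrap around to the other side
T2_shifted = numpy.roll(T2_toy, shift=2, axis=2)
print(numpy.argwhere(T2_shifted[1].any(axis=0)).ravel())  # [4 5 6 7]

# An in-plane shift along both y and x
T2_shifted_2d = numpy.roll(T2_toy, shift=(1, 2), axis=(1, 2))
```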

__TODO__: One way to minimise spatial misalignment is to use image registration to ensure that $u$ and $v$ are well aligned. In the notebook `sirf_registration` you will find information about how to register two images and also how to resample one image based on the spatial transformation estimated from the registration. Try to use this to correct for the misalignment you introduced above. For a real-world example, at which point in the code would you have to carry out the registration and resampling? (Some more information can also be found at the end of notebook `de_Pierro_MAPEM`.)

## Pathologies
The images $u$ and $v$ show the same anatomy, just with a different contrast. Clinically more useful are of course images which show complementary image information.

__TODO__: Add a pathology to either $u$ or $v$ and see how this affects the reconstruction. For something more advanced, have a look at the notebook `BrainWeb`.

## Single anatomical prior
So far we have alternated between two reconstruction problems. Another option is to do a single regularised reconstruction and simply use a previously reconstructed image for regularisation.

__TODO__: Adapt the above code such that $u$ is reconstructed first without regularisation and is then used for a regularised reconstruction of $v$ without any further updates of $u$.

## Complementary k-space trajectories
We used the same k-space trajectory for $u$ and $v$. This is of course not ideal for such an optimisation, because the same k-space trajectory also means the same pattern of undersampling artefacts. Of course the artefacts in each image will be different because of the different image content but it still would be better if $u$ and $v$ were acquired with different k-space trajectories.

__TODO__: Create two different k-space trajectories and compare the results to a reconstruction using the same k-space trajectory for both images. In order to create a new k-space trajectory, we can take a subset of the available `AcquisitionData`. Here is a short tutorial:

After `preprocessed_data = mr.preprocess_acquisition_data(mr_acq)` we can find out which k-space points have been acquired in this acquisition using:

```encode_step_1 = preprocessed_data.parameter_info('kspace_encode_step_1')```

then we can define an index list of k-space points from `encode_step_1` which we want to use:

```acq_new_idx = numpy.arange(0, len(encode_step_1))```

This is of course boring, because we are using all k-space points but I am sure you will come up with something more clever. Now we create an empty `AcquisitionData` object and then add the `Acquisitions` which we have selected:

```
acq_new = preprocessed_data.new_acquisition_data(empty=True)

for jnd in range(len(acq_new_idx)):
    cacq = preprocessed_data.acquisition(acq_new_idx[jnd])
    acq_new.append_acquisition(cacq)
    
acq_new.sort() 
```

Now that we have our new `AcquisitionData`, we can create a new `AcquisitionModel` from it:

```
A1_new = mr.AcquisitionModel(acq_new, im_mr)
A1_new.set_coil_sensitivity_maps(csm)
```
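One possible way to obtain two complementary trajectories is sketched below in NumPy. It assumes that `encode_step_1` (as returned by `parameter_info('kspace_encode_step_1')` above) is a 1D array with one phase-encode index per acquisition, and it treats the median of the acquired indices as the k-space centre. `complementary_split` and `n_centre` are hypothetical: the `n_centre` lines closest to the centre are kept in both subsets and the remaining lines are assigned alternately to the two subsets.

```python
import numpy


def complementary_split(encode_step_1, n_centre):
    """Split acquisition indices into two complementary subsets.

    The n_centre phase-encode lines closest to the k-space centre are kept in both
    subsets; the remaining lines are assigned alternately to subset 1 and subset 2.
    """
    ky = numpy.asarray(encode_step_1)
    ky_values = numpy.unique(ky)  # sorted, distinct phase-encode indices
    # Assumption: the median of the acquired indices marks the k-space centre
    distance = numpy.abs(ky_values - numpy.median(ky_values))
    order = numpy.argsort(distance, kind="stable")
    shared = ky_values[order[:n_centre]]
    outer = numpy.sort(ky_values[order[n_centre:]])
    lines1 = numpy.concatenate([shared, outer[0::2]])
    lines2 = numpy.concatenate([shared, outer[1::2]])
    # Map the selected lines back to acquisition indices
    idx1 = numpy.flatnonzero(numpy.isin(ky, lines1))
    idx2 = numpy.flatnonzero(numpy.isin(ky, lines2))
    return idx1, idx2


idx1, idx2 = complementary_split(numpy.arange(16), n_centre=4)
print(idx1)  # [ 0  2  4  6  7  8  9 10 12 14]
print(idx2)  # [ 1  3  5  6  7  8  9 11 13 15]
```

Each of `idx1` and `idx2` could then be used as `acq_new_idx` in the loop of the tutorial above to build one `AcquisitionModel` per image.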


## Other regularisation options
In this example we used a TV-based regularisation, but of course other regularisers could also be used, such as directional TV.

__TODO__: Have a look at the __CIL__ notebook `02_Dynamic_CT` and adapt the `SmoothJointTV` class above to use directional TV. 
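As a starting point for this exercise, the sketch below evaluates one common smoothed form of directional TV in NumPy: the gradient of $v$ is projected with $D = I - \gamma\,\xi\xi^{T}$, where $\xi = \nabla u / \sqrt{|\nabla u|^{2} + \eta^{2}}$ and $0 \le \gamma \le 1$, before its magnitude is summed. The helper names, the default values and the extra smoothing parameter `eps` are assumptions for illustration; a `Function` subclass for CIL would additionally need the gradient, as in `SmoothJointTV`. Because the component of $\nabla v$ parallel to $\nabla u$ is removed, edges of $v$ that coincide with edges of $u$ are penalised much less than misaligned ones.

```python
import numpy


def grad(u):
    """Forward differences along y and x, zero at the last row/column."""
    gy = numpy.zeros_like(u)
    gx = numpy.zeros_like(u)
    gy[:-1, :] = u[1:, :] - u[:-1, :]
    gx[:, :-1] = u[:, 1:] - u[:, :-1]
    return gy, gx


def smooth_dtv(v, u, eta, gamma=1.0, eps=1e-12):
    """Smoothed directional TV of v, guided by the edges of u."""
    uy, ux = grad(u)
    vy, vx = grad(v)
    # Normalised edge direction of the guide image u
    scale = numpy.sqrt(uy**2 + ux**2 + eta**2)
    xiy, xix = uy / scale, ux / scale
    # Remove (a fraction gamma of) the component of grad v parallel to xi
    inner = xiy * vy + xix * vx
    py = vy - gamma * inner * xiy
    px = vx - gamma * inner * xix
    return numpy.sum(numpy.sqrt(py**2 + px**2 + eps**2))


yy, xx = numpy.mgrid[0:32, 0:32]
square = ((yy > 8) & (yy < 24) & (xx > 8) & (xx < 24)).astype(float)
aligned = smooth_dtv(2.0 * square, square, eta=1e-3)
misaligned = smooth_dtv(numpy.roll(2.0 * square, 3, axis=1), square, eta=1e-3)
print(aligned < misaligned)  # True
```

With `gamma=0.0` the projection disappears and the function reduces to a smoothed TV of $v$ alone.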