### Joint TV reconstruction ( sequence MR )

$$
\begin{equation}
(u^{*}, v^{*}) \in \underset{u,v}{\operatorname{argmin}} \frac{1}{2} \| A_{1} u - g\|^{2}_{2} +  \frac{1}{2} \| A_{2} v - h\|^{2}_{2} + \alpha\,\mathrm{JTV}_{\eta, \lambda}(u, v) 
\end{equation}
$$

* $\mathrm{JTV}_{\eta, \lambda}(u, v) = \sum \sqrt{ \lambda|\nabla u|^{2} + (1-\lambda)|\nabla v|^{2} + \eta^{2}}$
* $A_{1}$, $A_{2}$: `AcquisitionModel`
* $g$, $h$: MR `AcquisitionData`


## How to solve 

We use alternating minimisation: at each step one variable is kept fixed and we solve with respect to the other variable:

$$
\begin{align*}
u^{k+1} & = \underset{u}{\operatorname{argmin}} \frac{1}{2} \| A_{1} u - g\|^{2}_{2} + \alpha_{1}\,\mathrm{JTV}_{\eta, \lambda}(u, v^{k}) \\
v^{k+1} & = \underset{v}{\operatorname{argmin}} \frac{1}{2} \| A_{2} v - h\|^{2}_{2} + \alpha_{2}\,\mathrm{JTV}_{\eta, 1-\lambda}(u^{k+1}, v)
\end{align*}$$

<!-- * The *regularisation parameter* `alpha` should be different for each subproblem. But not to worry at this stage. Maybe we should use $\alpha_{1}, \alpha_{2}$ in front of the two JTVs and a $\lambda$, $1-\lambda$ for the first JTV and $1-\lambda$, $\lambda$, for the second JTV with $0<\lambda<1$. -->


In [None]:
# Import libraries

from cil.framework import  AcquisitionGeometry, BlockDataContainer, BlockGeometry
from cil.optimisation.functions import Function, OperatorCompositionFunction, SmoothMixedL21Norm, L1Norm, L2NormSquared, BlockFunction, MixedL21Norm, IndicatorBox, TotalVariation, LeastSquares
from cil.optimisation.operators import GradientOperator, BlockOperator, ZeroOperator, CompositionOperator,LinearOperator
from cil.optimisation.algorithms import PDHG, FISTA, GD
from cil.plugins.astra.operators import ProjectionOperator
from cil.plugins.astra.processors import FBP
from cil.plugins import TomoPhantom
from cil.utilities.display import show2D
from cil.utilities import noise

import matplotlib.pyplot as plt

import numpy as np

In [None]:
# Number of detector pixels on the panel
N = 256
detectors = N

# 180 projection angles spanning 0 to 180. np.linspace includes both end
# points, so 0 and 180 are both sampled.
# NOTE(review): set_angles is called without an explicit unit, so the
# library default applies (presumably degrees) - confirm.
angles = np.linspace(0, 180, 180, dtype='float32')

# Parallel-beam 2D acquisition geometry
ag = (
    AcquisitionGeometry.create_Parallel2D()
    .set_angles(angles)
    .set_panel(detectors, pixel_size=0.1)
)

# Image geometry associated with the acquisition geometry
ig = ag.get_ImageGeometry()

# Two TomoPhantom models (numbers 1 and 2), one per modality
phantom1, phantom2 = (TomoPhantom.get_ImageData(model, ig) for model in (1, 2))

show2D([phantom1, phantom2], origin="upper", cmap="inferno")

In [None]:
# Forward operators of the two modalities (same geometry, run on the GPU)
A1 = ProjectionOperator(ig, ag, device = 'gpu')
A2 = ProjectionOperator(ig, ag, device = 'gpu')

# Create noise-free acquisition data, e.g., 2 MR sequences, or MR and CT
sino1 = A1.direct(phantom1)
sino2 = A2.direct(phantom2)

# Simulate zero-mean Gaussian noise (standard deviation 0.5 and 1.0).
# NOTE: no random seed is set, so the noise differs from run to run.
n1 = np.random.normal(0, 0.5, size = ag.shape)
n2 = np.random.normal(0, 1.0, size = ag.shape)

# Acquisition data g: noisy sinogram of the 1st modality,
# with negative values clipped to zero
g = ag.allocate()
g.fill(n1 + sino1.array)
g.array[g.array<0]=0

# Acquisition data h: noisy sinogram of the 2nd modality,
# with negative values clipped to zero
h = ag.allocate()
h.fill(n2 + sino2.array)
h.array[h.array<0]=0

# Show phantoms and noisy sinograms
show2D([phantom1, phantom2,
        g, h], title = ['1st modality','2nd modality','Noisy Sinogram ($g$)','Noisy Sinogram ($h$)'], num_cols=2, cmap = 'inferno', origin="upper")

In [None]:
class ProjectionMap(LinearOperator):

    r"""Linear operator that selects one component of a block variable.

    ``direct`` maps a BlockDataContainer ``x`` to its component
    ``x.get_item(index)``. ``adjoint`` embeds a single component back into
    a block container: the component is placed at position ``index`` and
    every other component is zero.

    :param domain_geometry: block geometry of the input; must expose
        ``geometries`` when ``range_geometry`` is not given
    :param index: position of the component to select
    :param range_geometry: geometry of the selected component; defaults to
        ``domain_geometry.geometries[index]``
    """

    def __init__(self, domain_geometry, index, range_geometry=None):

        self.index = index
        if range_geometry is None:
            range_geometry = domain_geometry.geometries[self.index]

        super(ProjectionMap, self).__init__(domain_geometry=domain_geometry,
                                            range_geometry=range_geometry)

    def direct(self, x, out=None):
        r"""Return component ``index`` of ``x``.

        :param x: BlockDataContainer to select from
        :param out: optional container that receives a copy of the component
        :return: the selected component (not a copy) when ``out`` is None,
            otherwise ``out``
        """
        if out is None:
            return x.get_item(self.index)

        out.fill(x.get_item(self.index))
        return out

    def adjoint(self, x, out=None):
        r"""Embed ``x`` at position ``index`` of an otherwise zero block.

        :param x: data living in the range geometry
        :param out: optional BlockDataContainer to write into; it is
            overwritten completely
        :return: the block container holding the result
        """
        if out is None:
            out = self.domain_geometry().allocate(0)
        else:
            # A reused buffer may still hold old data in the components that
            # are not selected; the adjoint must be zero there.
            out *= 0

        out[self.index].fill(x)
        return out
  

In [None]:
class SmoothJointTV(Function):

    r"""Smoothed joint total variation of a pair of images :math:`(u, v)`.

    .. math::

        JTV(u, v) = \sum \sqrt{\lambda |\nabla u|^{2}
                               + (1-\lambda) |\nabla v|^{2} + \epsilon^{2}}

    The function is evaluated on a BlockDataContainer ``x = (u, v)``.
    ``gradient`` returns the partial gradient with respect to one of the two
    images only (selected by ``axis``); the other component of the returned
    block is zero.

    :param epsilon: smoothing parameter making the norm differentiable;
        must be non-zero
    :param axis: which variable to differentiate, 0 for ``u`` or 1 for ``v``
    :param lambda_par: weight of :math:`|\nabla u|^{2}`; the weight of
        :math:`|\nabla v|^{2}` is ``1 - lambda_par``
    :param domain_geometry: geometry of a single image. When omitted, the
        module-level ``ig`` is used (original behaviour).
    :raises ValueError: if ``epsilon`` is 0 or ``axis`` is not 0 or 1
    """

    def __init__(self, epsilon, axis, lambda_par, domain_geometry=None):

        #TODO L=??  (sqrt(8) is a placeholder, not a derived Lipschitz constant)
        super(SmoothJointTV, self).__init__(L=np.sqrt(8))

        if epsilon == 0:
            raise ValueError('Working with smooth JTV atm')
        if axis not in (0, 1):
            raise ValueError('axis must be 0 or 1, got {}'.format(axis))

        # smoothing parameter
        self.epsilon = epsilon

        # Which variable to differentiate
        self.axis = axis

        self.lambda_par = lambda_par

        if domain_geometry is None:
            # Backward compatible default: the notebook-level image geometry
            domain_geometry = ig

        # GradientOperator acting on a single image
        self.grad = GradientOperator(domain_geometry)

        # Geometry of the pair (u, v); used to allocate gradients
        self._pair_geometry = BlockGeometry(domain_geometry, domain_geometry)

    @staticmethod
    def _check_input(x, caller):
        """Raise ValueError unless ``x`` is a BlockDataContainer."""
        if not isinstance(x, BlockDataContainer):
            raise ValueError('{} expected BlockDataContainer, got {}'.format(caller, type(x)))

    def _smoothed_magnitude(self, x):
        """Return the pointwise square root term and the two image gradients."""
        grad_u = self.grad.direct(x.get_item(0))
        grad_v = self.grad.direct(x.get_item(1))

        magnitude = (self.lambda_par * grad_u.pnorm(2).power(2)
                     + (1 - self.lambda_par) * grad_v.pnorm(2).power(2)
                     + self.epsilon**2).sqrt()

        return magnitude, (grad_u, grad_v)

    def __call__(self, x):

        r"""Evaluate the smoothed joint TV.

        :param x: BlockDataContainer holding the two images ``(u, v)``
        :return: sum over all pixels of the smoothed joint gradient magnitude
        :raises ValueError: if ``x`` is not a BlockDataContainer
        """
        self._check_input(x, '__call__')

        magnitude, _ = self._smoothed_magnitude(x)
        return magnitude.sum()

    def gradient(self, x, out=None):

        r"""Partial gradient with respect to the image selected by ``axis``.

        :param x: BlockDataContainer holding the two images ``(u, v)``
        :param out: optional BlockDataContainer to write into; it is
            overwritten completely
        :return: block container whose component ``axis`` holds the partial
            gradient and whose other component is zero
        :raises ValueError: if ``x`` is not a BlockDataContainer
        """
        self._check_input(x, 'gradient')

        magnitude, image_gradients = self._smoothed_magnitude(x)

        if self.axis == 0:
            weight = self.lambda_par
        else:
            weight = 1 - self.lambda_par

        # Everything that depends on x is computed before out is touched,
        # so the result stays correct even if out aliases x.
        ratio = (weight * image_gradients[self.axis]).divide(magnitude)

        if out is None:
            # The result lives in the (u, v) domain, not in the range of the
            # gradient operator.
            out = self._pair_geometry.allocate(0)
        else:
            # Clear stale data in the component that is not differentiated.
            out *= 0

        self.grad.adjoint(ratio, out=out[self.axis])
        return out
            
            
            

In [None]:
# Regularisation weights of the two subproblems
alpha1, alpha2 = 0.2, 0.3

# Weight inside the joint TV term and its smoothing parameter
lambda_par = 0.6
epsilon = 1e-4

# Geometry of the stacked variable (two images on the same image geometry)
bg = BlockGeometry(ig, ig)

# Operators selecting the 1st / 2nd component of the stacked variable
L1, L2 = (ProjectionMap(bg, index=k) for k in (0, 1))

# Data-fidelity terms built from the noisy data g and h
f1 = 0.5 * L2NormSquared(b=g)
f2 = 0.5 * L2NormSquared(b=h)

# Joint TV terms: the second one uses the complementary weight 1 - lambda_par
JTV1 = alpha1 * SmoothJointTV(epsilon=epsilon, axis=0, lambda_par=lambda_par)
JTV2 = alpha2 * SmoothJointTV(epsilon=epsilon, axis=1, lambda_par=1 - lambda_par)

# Objective of each subproblem: data fidelity composed with A_i L_i, plus JTV
objective1 = OperatorCompositionFunction(f1, CompositionOperator(A1, L1)) + JTV1
objective2 = OperatorCompositionFunction(f2, CompositionOperator(A2, L2)) + JTV2

In [None]:
# Start from an all-zero stacked variable
x0 = bg.allocate(0.0)

# Settings shared by every inner gradient-descent run
gd_settings = dict(alpha=1e9, max_iteration=10, update_objective_interval=10)

for i in range(20):

    # First subproblem: minimise objective1 starting from the current x0
    gd1 = GD(x0, objective1, **gd_settings)
    gd1.run(verbose=0)

    # Second subproblem: minimise objective2 starting from the result above
    gd2 = GD(gd1.solution, objective2, **gd_settings)
    gd2.run(verbose=0)

    # The result of this sweep becomes the starting point of the next one
    x0.fill(gd2.solution)

    print(i)
    

In [None]:
# Show the final reconstructions next to the ground-truth phantoms.
# gd2 is the last solver run in each sweep and starts from gd1.solution, so
# gd2.solution is the final iterate; gd1.solution is only the intermediate one.
show2D(gd2.solution, cmap="inferno", origin="upper")
show2D([phantom1, phantom2], cmap="inferno", origin="upper")