### Focus

This notebook uses a toy example to explain:

* why scaling is necessary when using overlapping locally low rank (LLR) regularization for iterative image reconstruction.

### Intuition

In principle, **LLR regularization enforces low rankness in multi-contrast image series** [1,2]. Assume that we want to reconstruct a set of images with the shape `[N_x, N_y, N_contrast]`, where `N_x` and `N_y` are the width and height of the images and `N_contrast` is the number of contrasts. LLR regularization then involves the following steps:

  1. **Extract** 2D local image patches from each 2D image of the 3D multi-contrast image array. The patch shape `[N_patch_x, N_patch_y]` is much smaller than `[N_x, N_y]`;
  2. **Flatten** each extracted patch, which has the shape `[N_patch_x, N_patch_y, N_contrast]`, into a 2D array with the shape `[N_patch_x * N_patch_y, N_contrast]`;
  3. Apply **singular value thresholding (SVT)** [3] to every flattened 2D patch;
  4. **Convert** every thresholded 2D array back to its corresponding location in the multi-contrast image array.
  - In the 4th step, one has to be careful when using **overlapping** image patches! I will explain this with a toy example.
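Step 3 can be illustrated on a single flattened patch. The NumPy sketch below assumes the soft-thresholding form of SVT from [3]: every singular value is shrunk by a threshold `lam` and clipped at zero. The helper name `svt` and the value `lam=0.5` are illustrative choices, not taken from sigpy or from the code later in this notebook.

```python
import numpy as np


def svt(patch_matrix, lam):
    """Singular value soft-thresholding of a 2D array.

    Shrinks every singular value by lam and clips at zero, i.e. the
    proximal operator of lam * (nuclear norm).
    """
    u, s, vh = np.linalg.svd(patch_matrix, full_matrices=False)
    s_thresh = np.maximum(s - lam, 0.0)
    return (u * s_thresh) @ vh


# first flattened patch of the toy example: [N_patch_x * N_patch_y, N_contrast]
patch = np.array([[1.1, 2.1],
                  [1.2, 2.2],
                  [1.4, 2.4],
                  [1.5, 2.5]])

# two singular values; the second is much smaller than the first
print(np.linalg.svd(patch, compute_uv=False))

low_rank = svt(patch, lam=0.5)
print(np.linalg.matrix_rank(low_rank))  # → 1
```

The two columns of this patch are nearly linearly dependent (the second contrast is the first plus a constant), so one singular value is small and the threshold removes it, leaving a rank-1 patch.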

### Author

  * Zhengguo Tan <zhengguo.tan@gmail.com>


### References

[1] Trzasko J, Manduca A. Local versus global low-rank promotion in dynamic MRI series reconstruction. *Proceedings of the 19th Annual Meeting of ISMRM* (2011) page 4371. 

[2] Zhang T, Pauly JM, Levesque IR. Accelerated parameter mapping with a locally low rank constraint. *Magn Reson Med* (2015). doi: [10.1002/mrm.25161](https://doi.org/10.1002/mrm.25161)

[3] Cai JF, Candès EJ, Shen Z. A singular value thresholding algorithm for matrix completion. *SIAM J Optim* (2010). doi: [10.1137/080738970](https://doi.org/10.1137/080738970)

In [None]:
import numpy as np

#### 1. Let's construct a toy example of multi-contrast images

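Here, two contrasts are simulated, and each contrast image has the shape `[2, 3]`, i.e., 2 rows and 3 columns. The two images are stacked along the first axis, so the multi-contrast array `x` has the shape `[N_contrast, 2, 3] = [2, 2, 3]`, with the contrast dimension first rather than last.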


In [None]:
x1 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 1
x2 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 2

print('contrast #1: \n', x1)
print('contrast #2: \n', x2)

x = np.stack((x1, x2))

print('multi-contrast images: \n', x)


contrast #1: 
 [[1.1 1.2 1.3]
 [1.4 1.5 1.6]]
contrast #2: 
 [[2.1 2.2 2.3]
 [2.4 2.5 2.6]]
multi-contrast images: 
 [[[1.1 1.2 1.3]
  [1.4 1.5 1.6]]

 [[2.1 2.2 2.3]
  [2.4 2.5 2.6]]]


#### 2. Extract local patches

Here, we use a patch size of `[2, 2]` and a stride of `[1, 1]`. Each `[2, 3]` image therefore yields two patches, which overlap in the middle column.

###### The two extracted patches and their flattened forms are then:

In [None]:
print('      contrast #1 \t\t\t   contrast #2')

print('\n(a) the 1st local patch masked by *')

print('    * * * * * * * \t\t\t * * * * * * *')
print('    * ' + str(x1[0, 0]) + '   ' + str(x1[0, 1]) + ' * ' + str(x1[0, 2]) + '\t\t\t ' +\
      '* ' + str(x2[0, 0]) + '   ' + str(x2[0, 1]) + ' * ' + str(x2[0, 2]))
print('    * ' + str(x1[1, 0]) + '   ' + str(x1[1, 1]) + ' * ' + str(x1[1, 2]) + '\t\t\t ' +\
      '* ' + str(x2[1, 0]) + '   ' + str(x2[1, 1]) + ' * ' + str(x2[1, 2]))
print('    * * * * * * * \t\t\t * * * * * * *')

print('\n(b) flatten')

print('    ' + str(x1[0, 0]) + ', ' + str(x2[0, 0]))
print('    ' + str(x1[0, 1]) + ', ' + str(x2[0, 1]))
print('    ' + str(x1[1, 0]) + ', ' + str(x2[1, 0]))
print('    ' + str(x1[1, 1]) + ', ' + str(x2[1, 1]))



print('\n(c) the 2nd local patch masked by *')

print('          * * * * * * * \t\t       * * * * * * *')
print('      ' + str(x1[0, 0]) + ' * ' + str(x1[0, 1]) + '   ' + str(x1[0, 2]) + ' * ' + '\t\t   ' +\
      str(x2[0, 0]) + ' * ' + str(x2[0, 1]) + '   ' + str(x2[0, 2]) + ' * ')
print('      ' + str(x1[1, 0]) + ' * ' + str(x1[1, 1]) + '   ' + str(x1[1, 2]) + ' * ' + '\t\t   ' +\
      str(x2[1, 0]) + ' * ' + str(x2[1, 1]) + '   ' + str(x2[1, 2]) + ' * ')
print('          * * * * * * * \t\t       * * * * * * *')

print('\n(d) flatten')

print('    ' + str(x1[0, 1]) + ', ' + str(x2[0, 1]))
print('    ' + str(x1[0, 2]) + ', ' + str(x2[0, 2]))
print('    ' + str(x1[1, 1]) + ', ' + str(x2[1, 1]))
print('    ' + str(x1[1, 2]) + ', ' + str(x2[1, 2]))


      contrast #1 			   contrast #2

(a) the 1st local patch masked by *
    * * * * * * * 			 * * * * * * *
    * 1.1   1.2 * 1.3			 * 2.1   2.2 * 2.3
    * 1.4   1.5 * 1.6			 * 2.4   2.5 * 2.6
    * * * * * * * 			 * * * * * * *

(b) flatten
    1.1, 2.1
    1.2, 2.2
    1.4, 2.4
    1.5, 2.5

(c) the 2nd local patch masked by *
          * * * * * * * 		       * * * * * * *
      1.1 * 1.2   1.3 * 		   2.1 * 2.2   2.3 * 
      1.4 * 1.5   1.6 * 		   2.4 * 2.5   2.6 * 
          * * * * * * * 		       * * * * * * *

(d) flatten
    1.2, 2.2
    1.3, 2.3
    1.5, 2.5
    1.6, 2.6
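The extraction and flattening shown in (a) to (d) can also be done programmatically. Below is a minimal NumPy sketch, assuming NumPy >= 1.20 for `numpy.lib.stride_tricks.sliding_window_view`; it is written for readability and is not how sigpy implements `ArrayToBlocks`. Along each spatial axis there are `(N - N_patch) // stride + 1` patch positions, which gives 1 x 2 = 2 patches here.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

# toy multi-contrast array, shape [N_contrast, 2, 3] as in step 1
x1 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 1
x2 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 2
x = np.stack((x1, x2))

blk_shape = (2, 2)
blk_strides = (1, 1)

# all 2x2 windows over the two spatial axes; the contrast axis is kept,
# so windows has shape [N_contrast, n_row_positions, n_col_positions, 2, 2]
windows = sliding_window_view(x, blk_shape, axis=(1, 2))
# apply the stride by subsampling the window positions
windows = windows[:, ::blk_strides[0], ::blk_strides[1]]

n_contrast = x.shape[0]
flattened = []
for i in range(windows.shape[1]):
    for j in range(windows.shape[2]):
        patch = windows[:, i, j]  # [N_contrast, 2, 2]
        # flatten to [N_patch_x * N_patch_y, N_contrast], pixels in row-major order
        flattened.append(patch.reshape(n_contrast, -1).T)

for k, p in enumerate(flattened):
    print(f'patch #{k + 1}:\n', p)
```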


#### 3. Convert the local patches back to full-size 3D arrays

Here, we skip the SVT step and focus on why the conversion step needs scaling when overlapping local patches are used.

As you can see in the step above, the second column of each contrast appears in both patches.

As a result, during the conversion, each overlapping location accumulates the values from every patch that contains it, so the second column is doubled.

The converted array then becomes:

In [None]:
print('      contrast #1 \t\t\t   contrast #2')

print('\n(e) the converted array')

print('      ' + str(x1[0, 0]) + '   ' + str(x1[0, 1] + x1[0, 1]) + '   ' + str(x1[0, 2]) + '\t\t\t ' +\
      '  ' + str(x2[0, 0]) + '   ' + str(x2[0, 1] + x2[0, 1]) + '   ' + str(x2[0, 2]))
print('      ' + str(x1[1, 0]) + '   ' + str(x1[1, 1] + x1[1, 1]) + '   ' + str(x1[1, 2]) + '\t\t\t ' +\
      '  ' + str(x2[1, 0]) + '   ' + str(x2[1, 1] + x2[1, 1]) + '   ' + str(x2[1, 2]))

      contrast #1 			   contrast #2

(e) the converted array
      1.1   2.4   1.3			   2.1   4.4   2.3
      1.4   3.0   1.6			   2.4   5.0   2.6
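The accumulation in (e) can be reproduced with an explicit scatter-add. The sketch below is plain NumPy, loop-based for clarity, and assumes the `[N_contrast, N_x, N_y]` layout of `x` used in this notebook. It extracts every patch and immediately adds it back at its location, skipping SVT exactly as in the text.

```python
import numpy as np

# toy multi-contrast array, shape [N_contrast, 2, 3]
x1 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 1
x2 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 2
x = np.stack((x1, x2))

blk = (2, 2)
stride = (1, 1)
n_rows = (x.shape[1] - blk[0]) // stride[0] + 1  # patch positions along rows
n_cols = (x.shape[2] - blk[1]) // stride[1] + 1  # patch positions along columns

# extract each patch and add it back at its original location,
# i.e. the "convert" step without SVT
y = np.zeros_like(x)
for i in range(n_rows):
    for j in range(n_cols):
        r, c = i * stride[0], j * stride[1]
        patch = x[:, r:r + blk[0], c:c + blk[1]]
        y[:, r:r + blk[0], c:c + blk[1]] += patch  # overlapping pixels accumulate

print(y)
```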


#### 4. Solution to overlapping LLR

(1) Define an array of ones with the same shape as the 3D multi-contrast array.

(2) Extract patches from this ones array, and then convert these patches back to a full-size array.

(3) The converted ones array counts how many patches cover each element, i.e., how many times the overlapping happens! Dividing the converted image array element-wise by this count array undoes the accumulation.

In [None]:
print('      contrast #1 \t\t\t   contrast #2')

print('\n(f) the converted ones array')

print('      ' + str(1) + '   ' + str(2) + '   ' + str(1) + '\t\t\t ' +\
      '  ' + str(1) + '   ' + str(2) + '   ' + str(1))
print('      ' + str(1) + '   ' + str(2) + '   ' + str(1) + '\t\t\t ' +\
      '  ' + str(1) + '   ' + str(2) + '   ' + str(1))

      contrast #1 			   contrast #2

(f) the converted ones array
      1   2   1			   1   2   1
      1   2   1			   1   2   1
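Steps (1) to (3), followed by the element-wise division, can be sketched in NumPy as follows. The helper `blocks_round_trip` is a hypothetical name for "extract all patches, then add them back"; it is our own loop-based sketch, not a sigpy function.

```python
import numpy as np


def blocks_round_trip(a, blk=(2, 2), stride=(1, 1)):
    """Hypothetical helper: extract all patches of `a` (shape
    [N_contrast, N_x, N_y]) and add them back at their locations,
    so overlapping pixels accumulate."""
    n_rows = (a.shape[1] - blk[0]) // stride[0] + 1
    n_cols = (a.shape[2] - blk[1]) // stride[1] + 1
    out = np.zeros_like(a, dtype=float)
    for i in range(n_rows):
        for j in range(n_cols):
            r, c = i * stride[0], j * stride[1]
            out[:, r:r + blk[0], c:c + blk[1]] += a[:, r:r + blk[0], c:c + blk[1]]
    return out


x1 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 1
x2 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 2
x = np.stack((x1, x2))

count = blocks_round_trip(np.ones_like(x))  # (1)-(3): how often each pixel is covered
y = blocks_round_trip(x)                    # accumulated array, middle column doubled
y_corr = y / count                          # element-wise scaling undoes the accumulation

print(count)
print(np.allclose(y_corr, x))  # → True
```

The count array depends only on the image shape, patch size, and stride, so it can be computed once and reused in every iteration of an iterative reconstruction.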


#### 5. Now let's do it in SigPy!

###### 5.1 Install SigPy (from the author's GitHub repository)

In [None]:
!git clone https://github.com/ZhengguoTan/sigpy.git

Cloning into 'sigpy'...
remote: Enumerating objects: 6070, done.[K
remote: Counting objects: 100% (606/606), done.[K
remote: Compressing objects: 100% (221/221), done.[K
remote: Total 6070 (delta 395), reused 534 (delta 373), pack-reused 5464[K
Receiving objects: 100% (6070/6070), 3.73 MiB | 10.53 MiB/s, done.
Resolving deltas: 100% (4415/4415), done.


In [None]:
%cd /content/sigpy
!git log -1

/content/sigpy
commit 9ac1824782788c9b36e43a78c77a139437597f16 (HEAD -> master, origin/master, origin/HEAD)
Author: Zhengguo Tan <zhengguo.tan@gmail.com>
Date:   Thu Dec 22 14:55:23 2022 +0100

    add verbose option in ConjugateGradient


In [None]:
!pip install matplotlib
!pip install mpi4py
!pip install -e /content/sigpy

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mpi4py
  Downloading mpi4py-3.1.4.tar.gz (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m32.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: mpi4py
  Building wheel for mpi4py (pyproject.toml) ... [?25l[?25hdone
  Created wheel for mpi4py: filename=mpi4py-3.1.4-cp39-cp39-linux_x86_64.whl size=3380623 sha256=2aefa54d9d4cb9e32d89bd236cbfaa247c87df8c566d6894dcfd27f662167a79
  Stored in directory: /root/.cache/pip/wheels/db/81/9f/43a031fce121c845baca1c5d9a1468cad98208286aa2832de9
Successfully built mpi4py
Installing collected packages: mpi4py

###### 5.2 We continue with the toy 3D array from above

In [None]:
import sigpy as sp
from sigpy import linop

In [None]:
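# here we define the ArrayToBlocks linear operator: blocks of shape (2, 2) are taken
# over the last two (spatial) axes of x, and the leading contrast axis is kept as a batch axis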

T = linop.ArrayToBlocks(x.shape, blk_shape=(2, 2), blk_strides=(1, 1))

In [None]:
# here we apply ArrayToBlocks and then its adjoint (i.e., BlocksToArray) to the multi-contrast 3D array

y = T.H * T * x

print('\n multi-contrast 3D array x: \n', x)

print('\n converted multi-contrast 3D array y: \n', y)


 multi-contrast 3D array x: 
 [[[1.1 1.2 1.3]
  [1.4 1.5 1.6]]

 [[2.1 2.2 2.3]
  [2.4 2.5 2.6]]]

 converted multi-contrast 3D array y: 
 [[[1.1 2.4 1.3]
  [1.4 3.  1.6]]

 [[2.1 4.4 2.3]
  [2.4 5.  2.6]]]
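Why does the adjoint sum overlapping values? `ArrayToBlocks` copies each pixel into every block that contains it, so its adjoint has to add all those copies back to that pixel. The NumPy sketch below checks this with the standard dot-product test, <T u, v> = <u, T^H v>. The helpers `extract_blocks` and `scatter_blocks` are hypothetical, and the block layout `[N_contrast, n_rows, n_cols, blk_x, blk_y]` is chosen to mirror what we believe sigpy produces for this input; the sketch does not call sigpy.

```python
import numpy as np


def extract_blocks(a, blk=(2, 2), stride=(1, 1)):
    """Forward operator T: [C, X, Y] -> [C, n_rows, n_cols, blk_x, blk_y]."""
    n_rows = (a.shape[1] - blk[0]) // stride[0] + 1
    n_cols = (a.shape[2] - blk[1]) // stride[1] + 1
    out = np.empty((a.shape[0], n_rows, n_cols) + tuple(blk), dtype=a.dtype)
    for i in range(n_rows):
        for j in range(n_cols):
            r, c = i * stride[0], j * stride[1]
            out[:, i, j] = a[:, r:r + blk[0], c:c + blk[1]]
    return out


def scatter_blocks(b, ishape, stride=(1, 1)):
    """Adjoint operator T^H: put every block back and sum where blocks overlap."""
    out = np.zeros(ishape, dtype=b.dtype)
    blk = b.shape[-2:]
    for i in range(b.shape[1]):
        for j in range(b.shape[2]):
            r, c = i * stride[0], j * stride[1]
            out[:, r:r + blk[0], c:c + blk[1]] += b[:, i, j]
    return out


rng = np.random.default_rng(0)
u = rng.standard_normal((2, 2, 3))        # random array in image space
v = rng.standard_normal((2, 1, 2, 2, 2))  # random array in block space

# dot-product test: <T u, v> must equal <u, T^H v> for an adjoint pair
lhs = np.vdot(extract_blocks(u), v)
rhs = np.vdot(u, scatter_blocks(v, u.shape))
print(np.isclose(lhs, rhs))  # → True
```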


In [None]:
# now we compute the scaling matrix to correct the overlapping problem

I = np.ones_like(x)
S = T.H * T * I

print('\n 3D ones array I: \n', I)

print('\n converted 3D array accounting for overlapping S: \n', S)


 3D ones array I: 
 [[[1. 1. 1.]
  [1. 1. 1.]]

 [[1. 1. 1.]
  [1. 1. 1.]]]

 converted 3D array accounting for overlapping S: 
 [[[1. 2. 1.]
  [1. 2. 1.]]

 [[1. 2. 1.]
  [1. 2. 1.]]]


In [None]:
# now we can apply this scaling array to correct y
y_corr = (T.H * T * x) / S

print('\n multi-contrast 3D array x: \n', x)

print('\n converted multi-contrast 3D array y_corr: \n', y_corr)


 multi-contrast 3D array x: 
 [[[1.1 1.2 1.3]
  [1.4 1.5 1.6]]

 [[2.1 2.2 2.3]
  [2.4 2.5 2.6]]]

 converted multi-contrast 3D array y_corr: 
 [[[1.1 1.2 1.3]
  [1.4 1.5 1.6]]

 [[2.1 2.2 2.3]
  [2.4 2.5 2.6]]]
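Putting the pieces together, one overlapping LLR step (extract, flatten, SVT, convert back, rescale) might look like the NumPy sketch below. It is not the sigpy implementation; the function name `llr_svt_overlapping`, the soft-thresholding form of SVT, and the choice to return zero at pixels that no patch covers are all assumptions of this sketch. Accumulating ones alongside the patches computes the same count array as `S` above.

```python
import numpy as np


def llr_svt_overlapping(x, lam, blk=(2, 2), stride=(1, 1)):
    """Hypothetical overlapping LLR step on x with shape [N_contrast, N_x, N_y]:
    extract patches, apply SVT to each flattened patch, add the results
    back, and divide by the per-pixel overlap count."""
    n_c = x.shape[0]
    n_rows = (x.shape[1] - blk[0]) // stride[0] + 1
    n_cols = (x.shape[2] - blk[1]) // stride[1] + 1
    out = np.zeros_like(x, dtype=float)
    count = np.zeros_like(x, dtype=float)
    for i in range(n_rows):
        for j in range(n_cols):
            r, c = i * stride[0], j * stride[1]
            sl = (slice(None), slice(r, r + blk[0]), slice(c, c + blk[1]))
            mat = x[sl].reshape(n_c, -1).T                 # [blk_x * blk_y, N_contrast]
            u, s, vh = np.linalg.svd(mat, full_matrices=False)
            mat = (u * np.maximum(s - lam, 0.0)) @ vh      # singular value soft-thresholding
            out[sl] += mat.T.reshape(n_c, blk[0], blk[1])  # convert back, overlaps accumulate
            count[sl] += 1.0                               # same conversion applied to ones
    # pixels not covered by any patch (possible when stride > block size)
    # are set to zero instead of triggering a division by zero
    return np.divide(out, count, out=np.zeros_like(out), where=count > 0)


x1 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 1
x2 = np.arange(1, 7, 1).reshape(2, 3) * 0.1 + 2
x = np.stack((x1, x2))

print(np.allclose(llr_svt_overlapping(x, lam=0.0), x))  # → True
x_llr = llr_svt_overlapping(x, lam=0.5)
```

With `lam=0.0` the step returns `x` unchanged, the same sanity check as `y_corr` above; without the division by the count array, the middle column would come back doubled.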


Thank you for your attention!