# SigPy for MRI Tutorial Part 1: Non-Uniform Fast Fourier Transform (NUFFT)

Welcome!

In this notebook, we will go through the basic features of SigPy, using the NUFFT as an example. The NUFFT is a core operation in non-Cartesian MRI, but it can still be a computational bottleneck in many applications. We will show how you can easily change parameters and computing devices in SigPy to speed things up.

Before moving on to the tutorial, we want to point out our [documentation](https://sigpy.readthedocs.io), where you can find more information about each function.


## Setup
If you haven't installed SigPy already, please follow the [installation instructions](https://sigpy.readthedocs.io/en/latest/index.html#installation), then come back to this tutorial.

SigPy is meant to be used alongside NumPy. In particular, SigPy operates on NumPy arrays directly and relies on NumPy for basic data manipulation functions. Almost always, we will import NumPy along with SigPy.

In addition, we will import the [sigpy.plot](https://sigpy.readthedocs.io/en/latest/plot.html) sub-module for plotting. The module provides convenient plotting functions for multi-dimensional arrays, controlled entirely with hotkeys. We will use them to show images, but will not focus on them in this tutorial.

Finally, the `%matplotlib notebook` magic gives us interactive plots in Jupyter notebooks.

In [1]:
%matplotlib notebook
import numpy as np
import sigpy as sp
import sigpy.plot as pl

We will use a non-Cartesian dataset created by Prof. Martin Uecker for the [ISMRM reproducible challenge](https://blog.ismrm.org/2019/04/02/ismrm-reproducible-research-study-group-2019-reproduce-a-seminal-paper-initiative/). The dataset contains k-space measurements of a brain scan acquired with a projection reconstruction (radial) trajectory. We have re-saved the dataset in NumPy's `.npy` format. Let us load the dataset and look at the array shapes.

In [2]:
ksp = np.load('data/radial_ksp.npy')
coord = np.load('data/radial_coord.npy')

print('k-space array shape: {}'.format(ksp.shape))
print('coordinate array shape: {}'.format(coord.shape))

k-space array shape: (12, 96, 512)
coordinate array shape: (96, 512, 2)


A few words about array shapes. NumPy stores arrays in row-major (C) order by default, so the array shapes are reversed compared with what you would expect in MATLAB or Fortran, which use column-major order. The k-space array is arranged as (number of coil channels, number of repetitions, number of readout points). The coordinate array is arranged as (number of repetitions, number of readout points, number of dimensions). You can tell this is a 2D dataset because the last dimension of `coord` is 2.
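To make the row-major point concrete, here is a small NumPy sketch that uses zero-filled placeholder arrays with the same shapes as `ksp` and `coord` (the actual data is not needed). It prints the byte strides in C and Fortran order, and checks that the trailing dimensions of the k-space array match the leading dimensions of the coordinate array, which is how the two arrays line up.

```python
import numpy as np

# Placeholder arrays with the same shapes as the radial dataset.
ksp = np.zeros((12, 96, 512), dtype=np.complex64)  # (coils, repetitions, readout)
coord = np.zeros((96, 512, 2), dtype=np.float32)    # (repetitions, readout, ndim)

# Row-major (C) order: the LAST axis is contiguous in memory.
print(ksp.strides)                     # (393216, 4096, 8)
# Column-major (Fortran/MATLAB) order makes the FIRST axis contiguous instead.
print(np.asfortranarray(ksp).strides)  # (8, 96, 9216)

# The sample dimensions of ksp must match the leading dimensions of coord,
# and the last axis of coord gives the number of spatial dimensions.
assert ksp.shape[-(coord.ndim - 1):] == coord.shape[:-1]
num_coils, ndim = ksp.shape[0], coord.shape[-1]
print(num_coils, ndim)  # 12 2
```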

## Gridding Reconstruction

For gridding reconstruction, we need a density compensation factor (`dcf`). For the projection reconstruction trajectory, the sampling density falls off inversely with the distance from the k-space center, so the `dcf` is proportional to the radius of the k-space coordinates.

Let us compute the density compensation factor and visualize it with the function [ScatterPlot](https://sigpy.readthedocs.io/en/latest/generated/sigpy.plot.ScatterPlot.html#sigpy.plot.ScatterPlot).

In [3]:
dcf = (coord[..., 0]**2 + coord[..., 1]**2)**0.5
pl.ScatterPlot(coord, dcf, hide_axes=True)

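If you don't have the dataset at hand, you can still experiment with a synthetic trajectory. The hypothetical helper `radial_coord` below builds a 2D projection reconstruction trajectory with the same shape as `coord` (96 spokes of 512 readout points). It assumes evenly spaced spoke angles over [0, π) and SigPy's convention that coordinates range from -N/2 to N/2; the real dataset's angle ordering may differ. The density compensation factor is then the Euclidean norm of each coordinate, which is equivalent to the explicit square-root expression used for `dcf`.

```python
import numpy as np


def radial_coord(num_spokes=96, num_readout=512):
    """Hypothetical helper: 2D radial (projection reconstruction) trajectory.

    Returns an array of shape (num_spokes, num_readout, 2) whose readout
    positions run from -num_readout/2 to num_readout/2 - 1.
    """
    angles = np.pi * np.arange(num_spokes) / num_spokes  # evenly spaced in [0, pi)
    r = np.arange(num_readout) - num_readout // 2        # signed radius along a spoke
    kx = r[None, :] * np.cos(angles)[:, None]
    ky = r[None, :] * np.sin(angles)[:, None]
    return np.stack([kx, ky], axis=-1)


coord = radial_coord()
# Density compensation proportional to the k-space radius |k|.
dcf = np.linalg.norm(coord, axis=-1)
print(coord.shape, dcf.shape)  # (96, 512, 2) (96, 512)
```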

The imported arrays are NumPy arrays, so SigPy can operate on them directly. To perform a gridding reconstruction, we can simply call [nufft_adjoint](https://sigpy.readthedocs.io/en/latest/generated/sigpy.nufft_adjoint.html#sigpy.nufft_adjoint) on the density-compensated input `ksp * dcf`. This gives us the gridded multi-channel images:

In [5]:
img_grid = sp.nufft_adjoint(ksp * dcf, coord)
pl.ImagePlot(img_grid, z=0, hide_axes=True)

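To see what `nufft_adjoint` approximates, the sketch below computes the forward and adjoint non-uniform DFT by brute force in NumPy. The names `ndft` and `ndft_adjoint` are hypothetical. It assumes a centered image grid and 1/sqrt(N) scaling, which we believe mirrors SigPy's centered, orthonormal FFT convention, but treat the exact scaling as an assumption. The explicit sums cost O(M·N) for M samples and N pixels, so this is only usable for tiny problems; the NUFFT approximates the same sums by interpolating onto an oversampled Cartesian grid with a small kernel and then applying an FFT.

```python
import numpy as np


def _phase(coord, shape):
    """k.x / N for every (sample, pixel) pair: shape (*coord.shape[:-1], N0, N1)."""
    n0, n1 = shape
    x0 = np.arange(n0) - n0 // 2  # centered pixel positions along axis 0
    x1 = np.arange(n1) - n1 // 2  # centered pixel positions along axis 1
    k0 = coord[..., 0, None, None]
    k1 = coord[..., 1, None, None]
    return k0 * x0[:, None] / n0 + k1 * x1[None, :] / n1


def ndft(img, coord):
    """Brute-force forward non-uniform DFT of a 2D image at the given coords."""
    phase = _phase(coord, img.shape)
    return np.sum(img * np.exp(-2j * np.pi * phase), axis=(-2, -1)) / np.sqrt(img.size)


def ndft_adjoint(ksp, coord, shape):
    """Brute-force adjoint NDFT: samples -> 2D images of the given shape.

    Leading axes of ksp beyond those of coord (e.g. coils) are kept.
    """
    kernel = np.exp(2j * np.pi * _phase(coord, shape))  # (*coord.shape[:-1], N0, N1)
    nsamp_axes = coord.ndim - 1
    prod = ksp[..., None, None] * kernel
    # Sum over the sample axes only, keeping any batch (coil) axes in front.
    axes = tuple(range(-2 - nsamp_axes, -2))
    return np.sum(prod, axis=axes) / np.sqrt(shape[0] * shape[1])
```

Leading axes of the k-space input beyond those of `coord`, such as the coil axis, are carried through to the output; `nufft_adjoint` treats the 12-channel `ksp * dcf` in the same way, returning one image per coil.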

We will combine the coil images by performing root-sum-of-squares along the coil dimension. This can be easily done using NumPy operations.

In [6]:
img_rss = np.sum(np.abs(img_grid)**2, axis=0)**0.5
pl.ImagePlot(img_rss, hide_axes=True)

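The same root-sum-of-squares combination can also be written with `np.linalg.norm` along the coil axis, because for complex input it takes the square root of the summed squared magnitudes. A minimal sketch, with random complex data standing in for `img_grid`:

```python
import numpy as np


def rss(coil_imgs, axis=0):
    """Root-sum-of-squares coil combination along `axis`."""
    return np.linalg.norm(coil_imgs, axis=axis)


rng = np.random.default_rng(0)
# Stand-in for img_grid: (coils, ny, nx) complex images.
coil_imgs = rng.standard_normal((12, 16, 16)) + 1j * rng.standard_normal((12, 16, 16))
img_rss = rss(coil_imgs)
print(img_rss.shape)  # (16, 16)
```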

The gridding reconstruction isn't too slow, but we might want it to be faster, for example when reconstructing a stack of images. One way to speed it up is to tune the oversampling ratio (`oversamp`) and kernel width (`width`) of [nufft_adjoint](https://sigpy.readthedocs.io/en/latest/generated/sigpy.nufft_adjoint.html#sigpy.nufft_adjoint): a smaller oversampled grid and a narrower interpolation kernel reduce computation, at the cost of more aliasing and interpolation artifacts.

We will prepend the IPython magic `%time` to the `nufft_adjoint` call to time it. Try changing the NUFFT parameters and see how they affect the run time and the artifacts!

In [7]:
%time img_grid_tune = sp.nufft_adjoint(ksp * dcf, coord, oversamp=1, width=2)

img_rss_tune = np.sum(np.abs(img_grid_tune)**2, axis=0)**0.5
pl.ImagePlot(img_rss_tune, hide_axes=True)

CPU times: user 882 ms, sys: 6.86 ms, total: 889 ms
Wall time: 73.5 ms


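To compare several settings systematically instead of editing the cell by hand, a small timing harness like the hypothetical `time_nufft_params` below can loop over `oversamp` and `width` values. It only assumes a callable that accepts those two keyword arguments. For reference, SigPy's defaults are, to our knowledge, `oversamp=1.25` and `width=4` in recent versions. When timing GPU code, pass a `sync` callable, since CuPy launches kernels asynchronously and the timer would otherwise stop before the work finishes.

```python
import itertools
import time


def time_nufft_params(recon, oversamps=(1.0, 1.25, 2.0), widths=(2, 4), sync=None):
    """Time recon(oversamp=..., width=...) for every parameter combination.

    recon: callable accepting `oversamp` and `width` keyword arguments.
    sync:  optional callable run before stopping the timer (e.g. a GPU
           synchronize), so asynchronous work is included in the timing.
    Returns a list of (oversamp, width, seconds) tuples, fastest first.
    """
    results = []
    for oversamp, width in itertools.product(oversamps, widths):
        start = time.perf_counter()
        recon(oversamp=oversamp, width=width)
        if sync is not None:
            sync()
        results.append((oversamp, width, time.perf_counter() - start))
    return sorted(results, key=lambda r: r[2])


# Example usage in this notebook (assumes ksp, dcf, coord from the cells above):
#   for oversamp, width, sec in time_nufft_params(
#           lambda **kw: sp.nufft_adjoint(ksp * dcf, coord, **kw)):
#       print(f'oversamp={oversamp}, width={width}: {sec * 1e3:.1f} ms')
# On GPU, one option is sync=cupy.cuda.Device().synchronize.
```

The first call may include one-time setup costs such as JIT compilation, so a warm-up call before timing gives fairer numbers.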

## Gridding Reconstruction on GPU

SigPy also allows you to run the same operations on a GPU. To run the NUFFT on a GPU, all we have to do is move the arrays to a GPU device and wrap the function call in that device's context. This is similar to how TensorFlow or PyTorch specify their computing device.

![device](https://sigpy.readthedocs.io/en/latest/_images/device.pdf)

To run the following code, you will need a CUDA-capable GPU and the `cupy` package installed.

In [9]:
device = sp.Device(0)

ksp_gpu = sp.to_device(ksp, device=device)
coord_gpu = sp.to_device(coord, device=device)
dcf_gpu = sp.to_device(dcf, device=device)

with device:
    img_grid_gpu = sp.nufft_adjoint(ksp_gpu * dcf_gpu, coord_gpu)

pl.ImagePlot(img_grid_gpu, z=0, hide_axes=True)


In [10]:
xp = device.xp
with device:
    img_rss_gpu = xp.sum(xp.abs(img_grid_gpu)**2, axis=0)**0.5

pl.ImagePlot(img_rss_gpu, hide_axes=True)

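The `xp` idiom lets the same code run on NumPy or CuPy arrays. A device-agnostic version of the root-sum-of-squares step might look like the sketch below, which picks the array module from the input and falls back to NumPy when CuPy is not installed. SigPy also provides its own `sp.get_array_module`, to our knowledge; the helper here is a stand-alone illustration, not SigPy's implementation.

```python
import numpy as np

try:
    import cupy as cp  # optional: only available on GPU-enabled installs
except ImportError:
    cp = None


def get_array_module(x):
    """Return cupy for CuPy arrays and numpy otherwise (the 'xp' idiom)."""
    if cp is not None:
        return cp.get_array_module(x)
    return np


def rss(coil_imgs):
    """Root-sum-of-squares over the coil axis, on whichever device holds the data."""
    xp = get_array_module(coil_imgs)
    return xp.sum(xp.abs(coil_imgs) ** 2, axis=0) ** 0.5
```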

In [6]:
import sigpy.mri as mr

mps = mr.app.JsenseRecon(ksp, coord=coord).run()



In [7]:
pl.ImagePlot(mps)

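Besides displaying them, estimated sensitivity maps can be used to combine coil images. As an illustration (not a SigPy function), the hypothetical `sense_combine` below computes the sensitivity-weighted combination $\sum_c \overline{S_c} I_c / \sum_c |S_c|^2$, which, unlike root-sum-of-squares, preserves the image phase. It assumes the coil images and maps share the same image grid, for example `img_grid` and `mps` here if their shapes agree, and the result is only defined up to the overall scaling of the maps.

```python
import numpy as np


def sense_combine(coil_imgs, mps, eps=1e-12):
    """Sensitivity-weighted coil combination.

    coil_imgs, mps: complex arrays of shape (coils, ny, nx).
    eps guards against division by zero where all maps vanish.
    """
    num = np.sum(np.conj(mps) * coil_imgs, axis=0)
    den = np.sum(np.abs(mps) ** 2, axis=0)
    return num / np.maximum(den, eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16)) + 1j * rng.standard_normal((16, 16))
maps = rng.standard_normal((4, 16, 16)) + 1j * rng.standard_normal((4, 16, 16))
# Noise-free coil images I_c = S_c * x are combined back into x.
combined = sense_combine(maps * x, maps)
```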

In [26]:
lamda = 1e-6
img = mr.app.L1WaveletRecon(ksp, mps, coord=coord, lamda=lamda, device=0).run()



In [27]:
pl.ImagePlot(img)

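As we understand the SigPy documentation, `L1WaveletRecon` solves a problem of the form $\min_x \frac{1}{2}\|P F S x - y\|_2^2 + \lambda \|W x\|_1$, where the `lamda` argument (spelled that way because `lambda` is a Python keyword) sets $\lambda$. Proximal-gradient-type solvers handle the $\ell_1$ term by soft-thresholding the wavelet coefficients at each iteration, so larger `lamda` values zero out more coefficients. The sketch below shows that operator alone for complex inputs. SigPy ships its own version as `sp.soft_thresh`, to our knowledge, so this is for illustration only.

```python
import numpy as np


def soft_thresh(lamda, x):
    """Complex soft-thresholding: proximal operator of lamda * ||x||_1.

    Shrinks each magnitude by lamda (clamped at zero) and keeps the phase.
    """
    mag = np.abs(x)
    shrink = np.maximum(mag - lamda, 0.0)
    # Avoid 0/0 where mag == 0; shrink is already 0 there.
    return shrink / np.where(mag > 0, mag, 1.0) * x


out = soft_thresh(1.0, np.array([3.0, -0.5, -2.0, 0.0, 3 + 4j]))
print(out)  # approximately [2, 0, -1, 0, 2.4+3.2j]
```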