Skip to content

Commit

Permalink
[SHORT] Fix up codes for FFT class, making it more user friendly and …
Browse files Browse the repository at this point in the history
…fix grid search (#111)

* Fix FFT without for loops

* Fix fft op and allow sending mask

* PEP8

* Move to scipy

* Add workers

* Add n_jobs documentation

* Add samples

* Fix test and add minute change for grid search

Co-authored-by: chaithyagr <chaithyagr@gitlab.com>
  • Loading branch information
chaithyagr and chaithyagr committed Jan 13, 2021
1 parent 18e11fe commit 7b24267
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 29 deletions.
94 changes: 67 additions & 27 deletions mri/operators/fourier/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# System import
import warnings
import numpy as np
import scipy as sp

# Package import
from ..base import OperatorBase
from .utils import convert_locations_to_mask
from .utils import convert_locations_to_mask, convert_mask_to_locations
from modopt.interface.errors import warn

# Third party import
Expand Down Expand Up @@ -45,32 +46,47 @@ class FFT(OperatorBase):
the mask samples in the Fourier domain.
shape: tuple of int
shape of the image (not necessarly a square matrix).
n_coils: int, default 1
Number of coils used to acquire the signal in case of multiarray
receiver coils acquisition. If n_coils > 1, data shape must be
[n_coils, Nx, Ny, NZ]
n_coils: int, default 1
Number of coils used to acquire the signal in case of multiarray
receiver coils acquisition. If n_coils > 1, data shape must be
[n_coils, Nx, Ny, NZ]
n_jobs: int, default 1
Number of parallel workers to use for fourier computation
"""
def __init__(self, samples, shape, n_coils=1):
def __init__(self, shape, n_coils=1, samples=None, mask=None, n_jobs=1):
""" Initilize the 'FFT' class.
Parameters
----------
samples: np.ndarray
the mask samples in the Fourier domain.
shape: tuple of int
shape of the image (not necessarly a square matrix).
n_coils: int, default 1
Number of coils used to acquire the signal in case of
multiarray receiver coils acquisition. If n_coils > 1,
data shape must be equal to [n_coils, Nx, Ny, NZ]
n_coils: int, default 1
Number of coils used to acquire the signal in case of
multiarray receiver coils acquisition. If n_coils > 1,
data shape must be equal to [n_coils, Nx, Ny, NZ]
samples: np.ndarray, default None
the mask samples in the Fourier domain.
mask: np.ndarray, default None
the mask as a matrix with 1 at sample locations
please pass samples or mask
n_jobs: int, default 1
Number of parallel workers to use for fourier computation
All cores are used if -1
"""
self.samples = samples
self.shape = shape
self._mask = convert_locations_to_mask(self.samples, self.shape)
if mask is None and samples is None:
raise ValueError("Please pass either samples or mask as input")
if mask is None:
self.mask = convert_locations_to_mask(samples, self.shape)
self.samples = samples
else:
self.mask = mask
self.samples = convert_mask_to_locations(mask)
if n_coils <= 0:
warn("The number of coils should be strictly positive")
n_coils = 1
self.n_coils = n_coils
self.n_jobs = n_jobs

def op(self, img):
""" This method calculates the masked Fourier transform of a ND image.
Expand All @@ -88,18 +104,30 @@ def op(self, img):
images the coils dimension is put first
"""
if self.n_coils == 1:
return self._mask * np.fft.ifftshift(np.fft.fftn(
np.fft.fftshift(img), norm="ortho"))
return self.mask * sp.fft.ifftshift(sp.fft.fftn(
sp.fft.fftshift(img),
norm="ortho",
workers=self.n_jobs,
))
else:
if self.n_coils > 1 and self.n_coils != img.shape[0]:
raise ValueError("The number of coils parameter is not equal"
"to the actual number of coils, the data must"
"be reshaped as [n_coils, Nx, Ny, Nz]")
else:
# TODO: Use joblib for parallelization
return np.asarray([self._mask * np.fft.ifftshift(np.fft.fftn(
np.fft.fftshift(img[ch]), norm="ortho"))
for ch in range(self.n_coils)])
axes = tuple(np.arange(1, img.ndim))
return self.mask * sp.fft.ifftshift(
sp.fft.fftn(
sp.fft.fftshift(
img,
axes=axes
),
axes=axes,
norm="ortho",
workers=self.n_jobs,
),
axes=axes
)

def adj_op(self, x):
""" This method calculates inverse masked Fourier transform of a ND
Expand All @@ -118,16 +146,28 @@ def adj_op(self, x):
For multichannel images the coils dimension is put first
"""
if self.n_coils == 1:
return np.fft.fftshift(np.fft.ifftn(
np.fft.ifftshift(self._mask * x), norm="ortho"))
return sp.fft.fftshift(sp.fft.ifftn(
sp.fft.ifftshift(self.mask * x),
norm="ortho",
workers=self.n_jobs,
))
else:
if self.n_coils > 1 and self.n_coils != x.shape[0]:
raise ValueError("The number of coils parameter is not equal"
"to the actual number of coils, the data must"
"be reshaped as [n_coils, Nx, Ny, Nz]")
else:
# TODO: Use joblib for parallelization
return np.asarray([np.fft.fftshift(np.fft.ifftn(
np.fft.ifftshift(self._mask * x[ch]),
norm="ortho"))
for ch in range(self.n_coils)])
x = x * self.mask
axes = tuple(np.arange(1, x.ndim))
return sp.fft.fftshift(
sp.fft.ifftn(
sp.fft.ifftshift(
x,
axes=axes
),
axes=axes,
norm="ortho",
workers=self.n_jobs,
),
axes=axes
)
1 change: 1 addition & 0 deletions mri/scripts/gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def launch_grid(kspace_data, reconstructor_class, reconstructor_kwargs,
{
'samples': fourier_op.samples,
'shape': fourier_op.shape,
'n_coils': fourier_op.n_coils,
}
}
fourier_op = None
Expand Down
7 changes: 5 additions & 2 deletions mri/tests/test_reconstructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_single_channel_reconstruction(self):
# mu is 0 for above single channel reconstruction and
# hence we expect the result to be the inverse fourier transform
np.testing.assert_allclose(
x_final, fourier_0.adj_op(data_0))
x_final, fourier_0.adj_op(data_0), rtol=1e-3)

def test_self_calibrating_reconstruction(self):
""" Test all the registered transformations.
Expand Down Expand Up @@ -202,7 +202,10 @@ def test_self_calibrating_reconstruction(self):
)
recon = fourier_0.adj_op(fourier_0.op(image_multichannel))
np.testing.assert_allclose(
np.abs(x_final), np.sqrt(np.sum(np.abs(recon)**2, axis=0)))
np.abs(x_final),
np.sqrt(np.sum(np.abs(recon)**2, axis=0)),
rtol=1e-3
)

def test_calibrationless_reconstruction(self):
""" Test all the registered transformations.
Expand Down

0 comments on commit 7b24267

Please sign in to comment.