Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SHORT] Fix up codes for FFT class, making it more user friendly and fix grid search #111

Merged
merged 8 commits into from
Jan 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 67 additions & 27 deletions mri/operators/fourier/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
# System import
import warnings
import numpy as np
import scipy as sp

# Package import
from ..base import OperatorBase
from .utils import convert_locations_to_mask
from .utils import convert_locations_to_mask, convert_mask_to_locations
from modopt.interface.errors import warn

# Third party import
Expand Down Expand Up @@ -45,32 +46,47 @@ class FFT(OperatorBase):
the mask samples in the Fourier domain.
shape: tuple of int
shape of the image (not necessarly a square matrix).
n_coils: int, default 1
Number of coils used to acquire the signal in case of multiarray
receiver coils acquisition. If n_coils > 1, data shape must be
[n_coils, Nx, Ny, NZ]
n_coils: int, default 1
Number of coils used to acquire the signal in case of multiarray
receiver coils acquisition. If n_coils > 1, data shape must be
[n_coils, Nx, Ny, NZ]
n_jobs: int, default 1
Number of parallel workers to use for fourier computation
"""
def __init__(self, samples, shape, n_coils=1):
def __init__(self, shape, n_coils=1, samples=None, mask=None, n_jobs=1):
""" Initilize the 'FFT' class.

Parameters
----------
samples: np.ndarray
the mask samples in the Fourier domain.
shape: tuple of int
shape of the image (not necessarly a square matrix).
n_coils: int, default 1
Number of coils used to acquire the signal in case of
multiarray receiver coils acquisition. If n_coils > 1,
data shape must be equal to [n_coils, Nx, Ny, NZ]
n_coils: int, default 1
Number of coils used to acquire the signal in case of
multiarray receiver coils acquisition. If n_coils > 1,
data shape must be equal to [n_coils, Nx, Ny, NZ]
samples: np.ndarray, default None
the mask samples in the Fourier domain.
mask: np.ndarray, default None
the mask as a matrix with 1 at sample locations
please pass samples or mask
zaccharieramzi marked this conversation as resolved.
Show resolved Hide resolved
n_jobs: int, default 1
Number of parallel workers to use for fourier computation
All cores are used if -1
"""
self.samples = samples
zaccharieramzi marked this conversation as resolved.
Show resolved Hide resolved
self.shape = shape
self._mask = convert_locations_to_mask(self.samples, self.shape)
if mask is None and samples is None:
raise ValueError("Please pass either samples or mask as input")
if mask is None:
self.mask = convert_locations_to_mask(samples, self.shape)
self.samples = samples
else:
self.mask = mask
self.samples = convert_mask_to_locations(mask)
if n_coils <= 0:
warn("The number of coils should be strictly positive")
n_coils = 1
self.n_coils = n_coils
self.n_jobs = n_jobs

def op(self, img):
""" This method calculates the masked Fourier transform of a ND image.
Expand All @@ -88,18 +104,30 @@ def op(self, img):
images the coils dimension is put first
"""
if self.n_coils == 1:
return self._mask * np.fft.ifftshift(np.fft.fftn(
np.fft.fftshift(img), norm="ortho"))
return self.mask * sp.fft.ifftshift(sp.fft.fftn(
sp.fft.fftshift(img),
norm="ortho",
workers=self.n_jobs,
))
else:
if self.n_coils > 1 and self.n_coils != img.shape[0]:
raise ValueError("The number of coils parameter is not equal"
"to the actual number of coils, the data must"
"be reshaped as [n_coils, Nx, Ny, Nz]")
else:
# TODO: Use joblib for parallelization
return np.asarray([self._mask * np.fft.ifftshift(np.fft.fftn(
np.fft.fftshift(img[ch]), norm="ortho"))
for ch in range(self.n_coils)])
axes = tuple(np.arange(1, img.ndim))
zaccharieramzi marked this conversation as resolved.
Show resolved Hide resolved
return self.mask * sp.fft.ifftshift(
sp.fft.fftn(
sp.fft.fftshift(
img,
axes=axes
),
axes=axes,
norm="ortho",
workers=self.n_jobs,
),
axes=axes
)

def adj_op(self, x):
""" This method calculates inverse masked Fourier transform of a ND
Expand All @@ -118,16 +146,28 @@ def adj_op(self, x):
For multichannel images the coils dimension is put first
"""
if self.n_coils == 1:
return np.fft.fftshift(np.fft.ifftn(
np.fft.ifftshift(self._mask * x), norm="ortho"))
return sp.fft.fftshift(sp.fft.ifftn(
sp.fft.ifftshift(self.mask * x),
norm="ortho",
workers=self.n_jobs,
))
else:
if self.n_coils > 1 and self.n_coils != x.shape[0]:
raise ValueError("The number of coils parameter is not equal"
"to the actual number of coils, the data must"
"be reshaped as [n_coils, Nx, Ny, Nz]")
else:
# TODO: Use joblib for parallelization
return np.asarray([np.fft.fftshift(np.fft.ifftn(
np.fft.ifftshift(self._mask * x[ch]),
norm="ortho"))
for ch in range(self.n_coils)])
x = x * self.mask
axes = tuple(np.arange(1, x.ndim))
return sp.fft.fftshift(
sp.fft.ifftn(
sp.fft.ifftshift(
x,
axes=axes
),
axes=axes,
norm="ortho",
workers=self.n_jobs,
),
axes=axes
)
1 change: 1 addition & 0 deletions mri/scripts/gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def launch_grid(kspace_data, reconstructor_class, reconstructor_kwargs,
{
'samples': fourier_op.samples,
'shape': fourier_op.shape,
'n_coils': fourier_op.n_coils,
}
}
fourier_op = None
Expand Down
7 changes: 5 additions & 2 deletions mri/tests/test_reconstructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def test_single_channel_reconstruction(self):
# mu is 0 for above single channel reconstruction and
# hence we expect the result to be the inverse fourier transform
np.testing.assert_allclose(
x_final, fourier_0.adj_op(data_0))
x_final, fourier_0.adj_op(data_0), rtol=1e-3)

def test_self_calibrating_reconstruction(self):
""" Test all the registered transformations.
Expand Down Expand Up @@ -202,7 +202,10 @@ def test_self_calibrating_reconstruction(self):
)
recon = fourier_0.adj_op(fourier_0.op(image_multichannel))
np.testing.assert_allclose(
np.abs(x_final), np.sqrt(np.sum(np.abs(recon)**2, axis=0)))
np.abs(x_final),
np.sqrt(np.sum(np.abs(recon)**2, axis=0)),
rtol=1e-3
)

def test_calibrationless_reconstruction(self):
""" Test all the registered transformations.
Expand Down