Skip to content

Commit

Permalink
Monkey patch scipy.fft to change defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentRDC committed Jan 14, 2021
1 parent ef3fe31 commit 232aad7
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 46 deletions.
46 changes: 46 additions & 0 deletions skued/fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""
Common module for fast fourier transforms.
"""

import scipy.fft
from functools import wraps
from scipy.fft import _pocketfft
from os import cpu_count

CPU_COUNT = cpu_count()


class SkuedPocketFFTBackend:
"""
FFT backend entirely based on Scipy's PocketFFT, with better defaults.
The speed of two-dimensional transforms may be improved by up to 50% with these
different defaults.
See also
--------
with_skued_fft
"""

__ua_domain__ = "numpy.scipy.fft"

@staticmethod
def __ua_function__(method, args, kwargs):
fn = getattr(_pocketfft, method.__name__, None)

if fn is None:
return NotImplemented
workers = kwargs.pop("workers", CPU_COUNT)
return fn(*args, workers=workers, **kwargs)


def with_skued_fft(f):
""" Ensure the use of the SkuedPocketFFTBackend whenever the `scipy.fft` module is used. """

@wraps(f)
def newf(*args, **kwargs):
with scipy.fft.set_backend(SkuedPocketFFTBackend):
return f(*args, **kwargs)

return newf
38 changes: 17 additions & 21 deletions skued/image/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
from warnings import warn
from npstreams import array_stream
from scipy import ndimage as ndi
from scipy.fft import set_workers
from ..fft import with_skued_fft
from skimage.registration import phase_cross_correlation


@array_stream
@with_skued_fft
def itrack_peak(images, row_slice=None, col_slice=None, precision=1 / 10):
"""
Generator function that tracks a diffraction peak in a stream of images.
Expand Down Expand Up @@ -55,17 +56,16 @@ def itrack_peak(images, row_slice=None, col_slice=None, precision=1 / 10):
# scikit-image will use scipy.fft module
# so we can increase the number of FFT workers
# to get a performance speedup (50% in my tests)
with set_workers(cpu_count()):
for image in images:
sub[:] = image[row_slice, col_slice]
for image in images:
sub[:] = image[row_slice, col_slice]

shift = phase_cross_correlation(
reference_image=ref,
moving_image=sub,
return_error=False,
upsample_factor=int(1 / precision),
)
yield np.asarray(shift)
shift = with_skued_fft(phase_cross_correlation)(
reference_image=ref,
moving_image=sub,
return_error=False,
upsample_factor=int(1 / precision),
)
yield np.asarray(shift)


def align(image, reference, mask=None, fill_value=0.0):
Expand All @@ -92,16 +92,12 @@ def align(image, reference, mask=None, fill_value=0.0):
--------
ialign : generator of aligned images
"""
# scikit-image will use scipy.fft module
# so we can increase the number of FFT workers
# to get a performance speedup (50% in my tests)
with set_workers(cpu_count()):
shift = phase_cross_correlation(
reference_image=reference,
moving_image=image,
reference_mask=mask,
return_error=False,
)
shift = with_skued_fft(phase_cross_correlation)(
reference_image=reference,
moving_image=image,
reference_mask=mask,
return_error=False,
)
return ndi.shift(image, shift=shift, order=2, mode="constant", cval=fill_value)


Expand Down
18 changes: 7 additions & 11 deletions skued/image/center.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import numpy as np
from skimage.registration import phase_cross_correlation
from scipy.ndimage import shift
from scipy.fft import set_workers
from ..fft import with_skued_fft


def autocenter(im, mask=None):
Expand Down Expand Up @@ -54,16 +54,12 @@ def autocenter(im, mask=None):
im_i = _fast_radial_inversion(im, center=(r_rough, c_rough), cval=0.0)
mask_i = _fast_radial_inversion(mask, center=(r_rough, c_rough), cval=False)

# scikit-image will use scipy.fft module
# so we can increase the number of FFT workers
# to get a performance speedup (50% in my tests)
with set_workers(cpu_count()):
shift = phase_cross_correlation(
reference_image=im,
moving_image=im_i,
reference_mask=mask,
moving_mask=mask_i,
)
shift = with_skued_fft(phase_cross_correlation)(
reference_image=im,
moving_image=im_i,
reference_mask=mask,
moving_mask=mask_i,
)

return np.array([r_rough, c_rough]) + shift / 2 - np.array([1 / 2, 1 / 2])

Expand Down
29 changes: 15 additions & 14 deletions skued/simulation/kinematic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@

import numpy as np
from os import cpu_count
from scipy.fft import next_fast_len, fft2, ifft2, fftshift, fftfreq, set_workers
import scipy.fft as fft
from ..fft import with_skued_fft
from .potential import pelectrostatic
from ..eproperties import interaction_parameter
from scipy.interpolate import RegularGridInterpolator


@with_skued_fft
def kinematicsim(crystal, kx, ky, energy=90):
"""
Propagate a plane wave through a crystal and compute the resulting
Expand All @@ -33,7 +35,7 @@ def kinematicsim(crystal, kx, ky, energy=90):
diff_pattern : `~numpy.ndarray`
Scattered intensity.
"""
shape = tuple(map(next_fast_len, kx.shape))
shape = tuple(map(fft.next_fast_len, kx.shape))
period_x, period_y, period_z = crystal.periodicity

# We create the grid ourselves so that we minimize Fourier artifacts as much as possible.
Expand All @@ -47,20 +49,19 @@ def kinematicsim(crystal, kx, ky, energy=90):
extent_y,
indexing="xy",
)
kx_, ky_ = fft2freq(xx, yy, indexing="xy")
kx_, ky_ = fft.fft2freq(xx, yy, indexing="xy")
k = np.hypot(kx_, ky_)

potential = pelectrostatic(crystal, xx, yy)
transmission_function = np.exp(1j * interaction_parameter(energy) * potential)

with set_workers(cpu_count()):
exit_wave = ifft2(
fft2(np.ones_like(xx, dtype=np.complex) * transmission_function)
)
intensity = fftshift(np.abs(fft2(exit_wave)) ** 2)
exit_wave = fft.ifft2(
fft.fft2(np.ones_like(xx, dtype=np.complex) * transmission_function)
)
intensity = fft.fftshift(np.abs(fft.fft2(exit_wave)) ** 2)

kx_ = fftshift(kx_)
ky_ = fftshift(ky_)
kx_ = fft.fftshift(kx_)
ky_ = fft.fftshift(ky_)

# Note that the definition of 'frequency' in fftfreq & friends necessitates dividing by 2pi
twopi = 2 * np.pi
Expand Down Expand Up @@ -105,8 +106,8 @@ def fft2freq(x, y, indexing="xy"):
spacing_x = abs(extent_x[1] - extent_x[0])
spacing_y = abs(extent_y[1] - extent_y[0])

freqs_x = fftfreq(len(extent_x), d=spacing_x)
freqs_y = fftfreq(len(extent_y), d=spacing_y)
freqs_x = fft.fftfreq(len(extent_x), d=spacing_x)
freqs_y = fft.fftfreq(len(extent_y), d=spacing_y)

return np.meshgrid(freqs_x, freqs_y, indexing=indexing)

Expand All @@ -129,6 +130,6 @@ def limit_bandwidth(image, K, limit):
limited : `~numpy.ndarray`
Bandwidth-limited image.
"""
image_fft = fft2(image)
image_fft = fft.fft2(image)
image_fft[K > limit] = 0.0
return ifft2(image_fft)
return fft.ifft2(image_fft)

0 comments on commit 232aad7

Please sign in to comment.