Skip to content

Commit

Permalink
Speedup of all fft calls
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentRDC committed Jan 14, 2021
1 parent 37220db commit ef3fe31
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Changelog
Release 2.1.2
-------------

* Speedup of :func:`autocenter`, :func:`align`, :func:`ialign`, and :func:`itrack_peak` by 50%.
* Speedup of all routines that use the Fast Fourier transform (:func:`autocenter`, :func:`align`, :func:`ialign`, :func:`itrack_peak`, and :func:`kinematicsim`) by 50%.

Release 2.1.1
-------------
Expand Down
24 changes: 13 additions & 11 deletions skued/simulation/kinematic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
"""

import numpy as np
from scipy.fftpack import next_fast_len
from os import cpu_count
from scipy.fft import next_fast_len, fft2, ifft2, fftshift, fftfreq, set_workers
from .potential import pelectrostatic
from ..eproperties import interaction_parameter
from scipy.interpolate import RegularGridInterpolator
Expand Down Expand Up @@ -52,13 +53,14 @@ def kinematicsim(crystal, kx, ky, energy=90):
potential = pelectrostatic(crystal, xx, yy)
transmission_function = np.exp(1j * interaction_parameter(energy) * potential)

exit_wave = np.fft.ifft2(
np.fft.fft2(np.ones_like(xx, dtype=np.complex) * transmission_function)
)
intensity = np.fft.fftshift(np.abs(np.fft.fft2(exit_wave)) ** 2)
with set_workers(cpu_count()):
exit_wave = ifft2(
fft2(np.ones_like(xx, dtype=np.complex) * transmission_function)
)
intensity = fftshift(np.abs(fft2(exit_wave)) ** 2)

kx_ = np.fft.fftshift(kx_)
ky_ = np.fft.fftshift(ky_)
kx_ = fftshift(kx_)
ky_ = fftshift(ky_)

# Note that the definition of 'frequency' in fftfreq & friends necessitates dividing by 2pi
twopi = 2 * np.pi
Expand Down Expand Up @@ -103,8 +105,8 @@ def fft2freq(x, y, indexing="xy"):
spacing_x = abs(extent_x[1] - extent_x[0])
spacing_y = abs(extent_y[1] - extent_y[0])

freqs_x = np.fft.fftfreq(len(extent_x), d=spacing_x)
freqs_y = np.fft.fftfreq(len(extent_y), d=spacing_y)
freqs_x = fftfreq(len(extent_x), d=spacing_x)
freqs_y = fftfreq(len(extent_y), d=spacing_y)

return np.meshgrid(freqs_x, freqs_y, indexing=indexing)

Expand All @@ -127,6 +129,6 @@ def limit_bandwidth(image, K, limit):
limited : `~numpy.ndarray`
Bandwidth-limited image.
"""
image_fft = np.fft.fft2(image)
image_fft = fft2(image)
image_fft[K > limit] = 0.0
return np.fft.ifft2(image_fft)
return ifft2(image_fft)

0 comments on commit ef3fe31

Please sign in to comment.