In [1]:
import numpy as np
import cupy as cp

In [2]:
# Build a small device array from a host list and display it.
x_gpu = cp.asarray([1, 2, 3])
x_gpu

array([1, 2, 3])

In [3]:
import cupy as cp
from cupyx.scipy.fft import get_fft_plan

# Batched input: 4 independent 400x400 complex64 arrays, transformed over axes 1 and 2.
a = cp.random.random(size=(4, 400, 400)).astype(cp.complex64)
# Precompute a batched, complex-to-complex (C2C) 2-D FFT plan for `a`.
plan = get_fft_plan(a, axes=(1, 2), value_type='C2C')

In [4]:
# Run the transform inside the plan's context so the precomputed plan is used.
with plan:
    # the arguments must match those used when generating the plan
    # (same array `a`, axes=(1, 2), complex-to-complex)
    out = cp.fft.fft2(a, axes=(1, 2))

In [5]:
# The 2-D FFT over axes (1, 2) keeps the batched input shape.
out.shape

(4, 400, 400)

In [6]:
with plan:
    # NOTE(review): CuPy launches GPU work asynchronously with respect to the
    # host, so %timeit without a device synchronize may not reflect GPU
    # execution time; cupyx.profiler.benchmark synchronizes and reports
    # CPU and GPU times separately.
    %timeit cp.fft.fft2(a, axes=(1, 2))

149 μs ± 1.63 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [7]:
from sympy import *

In [8]:
# Symbolic test function y(x) = cos(x) * exp(-x**2 / 25) and its exact
# derivative, each compiled to a NumPy-callable via lambdify so they can be
# evaluated on the numerical grid.
# NOTE(review): later cells rebind `x` and `y` to CuPy arrays; the plot title
# that formats `y` is only correct if this cell ran after those cells.
x = Symbol('x')
y = cos(x) * exp(-x**2 / 25)
f = lambdify(x, y, 'numpy')
# Exact derivative: the reference the FFT derivative is compared against.
# (The previous bare `simplify(dy_x)` discarded its result and had no effect.)
dy_x = diff(y, x)
df = lambdify(x, dy_x, 'numpy')

In [9]:
# Periodic grid of n points on [-L/2, L/2) and the matching angular wavenumbers.
n = 1024   # number of grid points
L = 30     # domain length
dx = L/n
# linspace(endpoint=False) always yields exactly n points on the periodic grid;
# np.arange with a float step can return n+1 points when dx is not exactly
# representable, which would then fail to broadcast against k.
X = np.linspace(-L/2, L/2, n, endpoint=False, dtype=np.complex128)
# Angular wavenumbers 2*pi*m/L in FFT order (0, 1, ..., -1). fftfreq is also
# correct for odd n, where np.arange(-n/2, n/2) would give half-integer modes.
k = 2*np.pi * np.fft.fftfreq(n, d=dx)

In [10]:
# Evaluate the test function and its exact derivative on the grid (host arrays).
u, du = f(X), df(X)

In [11]:
# Copy the samples and the wavenumbers to the GPU.
u_cp, k_cp = cp.asarray(u), cp.asarray(k)

In [12]:
def fft_derivative(u_cp, k_cp):
    """Spectral first derivative of sampled data via the FFT.

    Parameters
    ----------
    u_cp : cupy.ndarray
        Samples of the function, assumed to lie on a uniform periodic grid.
    k_cp : cupy.ndarray
        Angular wavenumbers in FFT order, broadcast against ``fft(u_cp)``
        (the transform runs along the last axis).

    Returns
    -------
    cupy.ndarray
        Real part of ``ifft(k * fft(u) * 1j)``.
    """
    spectrum = cp.fft.fft(u_cp)
    # Differentiation in Fourier space is multiplication of each mode by i*k.
    return cp.fft.ifft(k_cp * spectrum * (1.j)).real

In [13]:
@cp.fuse()
def multiply(x, y):
    """Element-wise product ``x * y``, fused by CuPy into a single kernel.

    Used to apply the precomputed ``i*k`` factor to an FFT spectrum.
    """
    product = x * y
    return product

In [14]:
ik_cp = 1j * k_cp  # precomputed i*k so the derivative needs only one multiply per mode

In [15]:
def fft_derivative2(u_cp, ik_cp):
    """Spectral first derivative using a precomputed ``i*k`` factor.

    Same result as ``fft_derivative`` but takes ``ik_cp = 1j * k`` directly and
    applies it with the fused ``multiply`` kernel.

    Parameters
    ----------
    u_cp : cupy.ndarray
        Samples of the function, assumed to lie on a uniform periodic grid.
    ik_cp : cupy.ndarray
        ``1j`` times the angular wavenumbers, in FFT order.

    Returns
    -------
    cupy.ndarray
        Real part of ``ifft(ik * fft(u))``.
    """
    # forward transform -> scale each mode by i*k -> back to physical space
    return cp.fft.ifft(multiply(ik_cp, cp.fft.fft(u_cp))).real

In [16]:
# Spectral derivative of u on the GPU, applying i*k with ordinary array ops.
duFFT_cp = fft_derivative(u_cp, k_cp)

In [17]:
# Same derivative, applying the precomputed i*k with the fused `multiply` kernel.
duFFT_cp2 = fft_derivative2(u_cp, ik_cp)

In [18]:
# NOTE(review): CuPy launches GPU work asynchronously, so %timeit without a
# device synchronize may not reflect GPU execution time; consider
# cupyx.profiler.benchmark for CPU/GPU timings.
%timeit -n10000 fft_derivative(u_cp, k_cp)

96.1 μs ± 3.82 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
%timeit -n10000 fft_derivative2(u_cp, ik_cp), 10000

69 μs ± 1.61 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [20]:
# Plot the sampled test function f(x) (host copy of the array sent to the GPU).
# NOTE(review): the title formats `y`, which must still be the SymPy expression;
# later cells rebind `y` to a CuPy array.
import plotly.graph_objects as go
fig = go.Figure()
fig.add_trace(go.Scatter(x=X.real, y=u_cp.get().real, mode='lines', name='Re f(x)'))
fig.update_layout(title=f'f(x) = {y}', xaxis_title='x', yaxis_title='f(x)')
fig.show()

In [21]:
# Compare the exact derivative with the spectral derivative from fft_derivative.
import plotly.graph_objects as go
fig = go.Figure()
fig.add_trace(go.Scatter(x=X.real, y=du.real, mode='lines', name='analytic'))
fig.add_trace(go.Scatter(x=X.real, y=duFFT_cp.get().real, mode='lines', name='FFT'))
fig.update_layout(title=f"f'(x) = {dy_x}", xaxis_title='x', yaxis_title="f'(x)")
fig.show()

In [22]:
# Compare the exact derivative with the spectral derivative from fft_derivative2
# (fused i*k multiply).
import plotly.graph_objects as go
fig = go.Figure()
fig.add_trace(go.Scatter(x=X.real, y=du.real, mode='lines', name='analytic'))
fig.add_trace(go.Scatter(x=X.real, y=duFFT_cp2.get().real, mode='lines', name='FFT (fused)'))
fig.update_layout(title=f"f'(x) = {dy_x}", xaxis_title='x', yaxis_title="f'(x)")
fig.show()

In [23]:
# Element-wise (x - y)**2 for float32 inputs; CuPy broadcasts x against y.
squared_diff = cp.ElementwiseKernel(
    'float32 x, float32 y',   # input parameters
    'float32 z',              # output parameter
    'z = (x - y) * (x - y)',  # per-element body
    name='squared_diff',
)

In [24]:
# A (2, 5) float32 matrix and a length-5 vector: y is broadcast across the rows.
x = cp.arange(10, dtype=np.float32).reshape((2, 5))
y = cp.arange(5, dtype=np.float32)

squared_diff(x, y)

array([[ 0.,  0.,  0.,  0.,  0.],
       [25., 25., 25., 25., 25.]], dtype=float32)

In [25]:
# Inspect the float32 input used above.
x

array([[0., 1., 2., 3., 4.],
       [5., 6., 7., 8., 9.]], dtype=float32)

In [26]:
# Type-generic (x - y)**2: `T` is a type placeholder resolved from the argument
# dtypes at call time, so the same kernel handles e.g. complex inputs.
squared_diff_generic = cp.ElementwiseKernel(
    'T x, T y',
    'T z',
    'z = (x - y) * (x - y)',
    name='squared_diff_generic',
)

In [27]:
# Complex inputs: x has real parts 0..9 in a (2, 5) layout and imaginary parts
# 5..1 on every row; y is real-valued 0..4 stored as complex128.
x = (cp.arange(10, dtype=np.complex128).reshape(2, 5)
     + 1j*(5 - cp.arange(5, dtype=np.complex128)))
y = cp.arange(5, dtype=np.complex128)

squared_diff_generic(x, y)

array([[-25. +0.j, -16. +0.j,  -9. +0.j,  -4. +0.j,  -1. +0.j],
       [  0.+50.j,   9.+40.j,  16.+30.j,  21.+20.j,  24.+10.j]])

In [28]:
# Inspect the complex input x.
x

array([[0.+5.j, 1.+4.j, 2.+3.j, 3.+2.j, 4.+1.j],
       [5.+5.j, 6.+4.j, 7.+3.j, 8.+2.j, 9.+1.j]])

In [29]:
# y is complex128 with zero imaginary part.
y

array([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j])

In [30]:
# For comparison with the kernel's explicit product: `**2` on complex values
# leaves ~1e-15 imaginary residue where the kernel result has exact zeros
# (presumably because it goes through a general complex power).
(x-y)**2

array([[-2.500000e+01+3.06161700e-15j, -1.600000e+01+1.95943488e-15j,
        -9.000000e+00+1.10218212e-15j, -4.000000e+00+4.89858720e-16j,
        -1.000000e+00+1.22464680e-16j],
       [ 3.061617e-15+5.00000000e+01j,  9.000000e+00+4.00000000e+01j,
         1.600000e+01+3.00000000e+01j,  2.100000e+01+2.00000000e+01j,
         2.400000e+01+1.00000000e+01j]])

In [31]:
# NOTE(review): this rebinds `squared_diff`, replacing the float32
# ElementwiseKernel of the same name defined earlier; any cell that calls
# `squared_diff` after this one gets the fused version. Consider a distinct name.
@cp.fuse()
def squared_diff(x, y):
    """Squared difference (x - y)**2 as a fused CuPy function.

    Also accepts NumPy arrays, returning a NumPy result in that case.
    """
    delta = x - y
    return delta * delta

In [32]:
# Two device vectors running in opposite directions.
x_cp, y_cp = cp.arange(10), cp.arange(10)[::-1]
squared_diff(x_cp, y_cp)

array([81, 49, 25,  9,  1,  1,  9, 25, 49, 81])

In [33]:
# The same fused function called with host (NumPy) arrays.
x_np, y_np = np.arange(10), np.arange(10)[::-1]
squared_diff(x_np, y_np)

array([81, 49, 25,  9,  1,  1,  9, 25, 49, 81])

In [34]:
# NOTE(review): for a 10-element array this mostly measures call/launch
# overhead; CuPy runs GPU work asynchronously, so use cupyx.profiler.benchmark
# for GPU execution time.
%timeit squared_diff(x_cp, y_cp)

19.1 μs ± 3.82 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
