
BUG: double-precision C2R IFFT interpreted as Z2R, but it should be Z2D #47

@kennykos


Problem description

My goal is to perform Fourier convolutions on a 4-way tensor as follows:

  1. R2C FFT (3x, fixing mode 1);
  2. Diagonal scaling (3x, fixing mode 1);
  3. C2R IFFT (3x, fixing mode 1).
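For reference, the three steps can be sketched on a single process with NumPy, which is useful for checking what the distributed version should produce. This is a minimal sketch, not the nvmath-python pipeline: it assumes the whole tensor fits in memory, that a circular convolution is intended, and that `kernel_hat` (a hypothetical name) holds the diagonal scaling in the R2C frequency domain, i.e. it has the shape of the R2C output of one 3-way slice.

```python
import numpy as np


def fourier_convolve(tensor: np.ndarray, kernel_hat: np.ndarray) -> np.ndarray:
    """Apply R2C FFT, diagonal scaling and C2R IFFT to every mode-1 slice.

    `kernel_hat` is a hypothetical frequency-domain multiplier whose shape
    matches np.fft.rfftn(tensor[0]).shape.
    """
    spatial_shape = tensor.shape[1:]
    out = np.empty_like(tensor)
    for i in range(tensor.shape[0]):  # fix mode 1, transform modes 2-4
        slab_hat = np.fft.rfftn(tensor[i])  # step 1: R2C FFT (last axis -> n // 2 + 1)
        slab_hat *= kernel_hat  # step 2: diagonal (elementwise) scaling
        out[i] = np.fft.irfftn(slab_hat, s=spatial_shape)  # step 3: C2R IFFT
    return out


rng = np.random.default_rng(0)
tensor = rng.random((4, 8, 6, 10))  # float64, the precision that fails in the issue
kernel_hat = rng.random((8, 6, 10 // 2 + 1))
result = fourier_convolve(tensor, kernel_hat)
```

The loop over mode 1 mirrors the structure described above; passing `s=spatial_shape` to `irfftn` recovers the original (even) length of the last axis, which cannot be inferred from the `n // 2 + 1` half-spectrum alone.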

Step 1 looks like

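for i in range(tensor.shape[0]):
    a[:] = tensor[i]  # copy the i-th 3-way slice (mode 1 fixed) into the FFT buffer
    fft(a)            # R2C FFT over the remaining three modes
    a[:] = 0          # clear the buffer before the next slice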

and step 3 is similar. However, my minimal working example (which works in single precision)

import cupy as cp
from mpi4py import MPI 

import nvmath.distributed

# Initialize nvmath.distributed.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
nranks = comm.Get_size()
device_id = rank % cp.cuda.runtime.getDeviceCount()
nvmath.distributed.initialize(device_id, comm)

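# Local (per-rank) shapes: the real input is split along X (Slab.X), while the
# complex R2C output (last axis 512 // 2 + 1) is split along Y (Slab.Y).
fft_shape = 512 // nranks, 256, 512
ifft_shape = 512, 256 // nranks, 512 // 2 + 1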

a = nvmath.distributed.fft.allocate_operand(
    fft_shape,
    cp, 
    input_dtype=cp.float64,
    distribution=nvmath.distributed.fft.Slab.X,
    fft_type="R2C",
)
b = nvmath.distributed.fft.allocate_operand(
    ifft_shape,
    cp, 
    input_dtype=cp.complex128,
    distribution=nvmath.distributed.fft.Slab.Y,
    fft_type="C2R",
)
with cp.cuda.Device(device_id):
    a[:] = cp.random.rand(*fft_shape, dtype=cp.float64)

b[:] = nvmath.distributed.fft.rfft(a, distribution=nvmath.distributed.fft.Slab.X, options={"reshape": False})

c = nvmath.distributed.fft.irfft(b, distribution=nvmath.distributed.fft.Slab.Y, options={"reshape": False})

nvmath.distributed.free_symmetric_memory(a)
nvmath.distributed.free_symmetric_memory(b)

breaks in double precision with

Traceback (most recent call last):
  File "/work/09661/gkk345/vista/tmp/dist_fft_c2r_r2c.py", line 35, in <module>
    c = nvmath.distributed.fft.irfft(b, distribution=nvmath.distributed.fft.Slab.Y, options={"reshape": False})
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/distributed/fft/fft.py", line 2435, in irfft
    return _fft(
           ^^^^^
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/distributed/fft/fft.py", line 2236, in _fft
    fftobj.plan(stream=stream)
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/internal/utils.py", line 588, in inner
    result = wrapped_function(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/internal/utils.py", line 554, in inner
    raise e
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/internal/utils.py", line 546, in inner
    result = wrapped_function(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/distributed/fft/fft.py", line 1520, in plan
    fft_concrete_type = _get_fft_concrete_type(self.operand_data_type, self.fft_abstract_type)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/site-packages/nvmath/distributed/fft/fft.py", line 190, in _get_fft_concrete_type
    return FFTType["Z2R"]
           ~~~~~~~^^^^^^^
  File "/work/09661/gkk345/vista/miniconda3/envs/sph/lib/python3.11/enum.py", line 790, in __getitem__
    return cls._member_map_[name]
           ~~~~~~~~~~~~~~~~^^^^^^
KeyError: 'Z2R'
ERROR:root:Symmetric heap memory needs to be deallocated explicitly
ERROR:root:Symmetric heap memory needs to be deallocated explicitly

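That is, my double-precision C2R IFFT (complex128 input) is being mapped to the concrete FFT type Z2R, which does not exist in FFTType (hence the KeyError), instead of Z2D, the double-precision complex-to-real type it should map to.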
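cuFFT names its transforms by precision: real-to-complex and complex-to-real are `R2C`/`C2R` in single precision and `D2Z`/`Z2D` in double precision, while complex-to-complex is `C2C`/`Z2Z`. A lookup that respects this naming can be sketched as follows. The function `concrete_fft_type` and its dtype-name keys are hypothetical illustrations; this is not nvmath-python's `_get_fft_concrete_type`.

```python
# Hypothetical mapping from (operand dtype, abstract FFT type) to the
# cuFFT-style concrete transform name. Not nvmath-python's implementation.
_CONCRETE_FFT_TYPES = {
    ("float32", "R2C"): "R2C",
    ("float64", "R2C"): "D2Z",
    ("complex64", "C2R"): "C2R",
    ("complex128", "C2R"): "Z2D",
    ("complex64", "C2C"): "C2C",
    ("complex128", "C2C"): "Z2Z",
}


def concrete_fft_type(operand_dtype: str, abstract_type: str) -> str:
    """Return the concrete type for an operand dtype and abstract FFT type."""
    try:
        return _CONCRETE_FFT_TYPES[(operand_dtype, abstract_type)]
    except KeyError:
        raise ValueError(
            f"unsupported combination: dtype={operand_dtype!r}, fft_type={abstract_type!r}"
        ) from None


print(concrete_fft_type("complex128", "C2R"))  # Z2D, not "Z2R"
```

Keying on the full (dtype, abstract type) pair avoids assembling the name from separate input and output letters, which is one way a nonexistent name such as `Z2R` could arise.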
