Merge pull request #12 from DifferentiableUniverseInitiative/fix_fft
Fix 3D FFT
ASKabalan committed Jun 11, 2024
2 parents e246d43 + 83554e0 commit d2fb678
Showing 13 changed files with 102 additions and 72 deletions.
3 changes: 2 additions & 1 deletion jaxdecomp/__init__.py
@@ -10,7 +10,7 @@
TRANSPOSE_COMM_MPI_P2P_PL, TRANSPOSE_COMM_NCCL, TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM, TRANSPOSE_COMM_NVSHMEM_PL, HaloCommBackend,
TransposeCommBackend, finalize, get_autotuned_config, get_pencil_info,
halo_exchange, make_config, slice_pad, slice_unpad)
halo_exchange, init, make_config, slice_pad, slice_unpad)

try:
__version__ = version("jaxDecomp")
@@ -20,6 +20,7 @@

__all__ = [
"config",
"init",
"finalize",
"get_pencil_info",
"get_autotuned_config",
1 change: 1 addition & 0 deletions jaxdecomp/_src/__init__.py
@@ -2,6 +2,7 @@

from . import _jaxdecomp

init = _jaxdecomp.init
finalize = _jaxdecomp.finalize
get_pencil_info = _jaxdecomp.get_pencil_info
get_autotuned_config = _jaxdecomp.get_autotuned_config
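The newly exported init is meant to run before jax.distributed.initialize(), mirroring the scripts and tests updated further down in this commit. A minimal lifecycle sketch (the workload in the middle is a placeholder, not part of the diff):

    import jax
    import jaxdecomp

    jaxdecomp.init()              # create the global cuDecomp grid-descriptor manager
    jax.distributed.initialize()  # instruct each local process which GPU to use

    # ... distributed FFTs, transposes, halo exchanges ...

    jaxdecomp.finalize()          # release cuDecomp resources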
81 changes: 46 additions & 35 deletions jaxdecomp/_src/fft.py
@@ -75,39 +75,31 @@ def sfft_abstract_eval(x, fft_type, pdims, global_shape, adjoint):
axis = 2

output_shape = None
match fft_type:
case xla_client.FftType.FFT:
# FFT is X to Y to Z so Z-Pencil is returned
# Except if we are doing a YZ slab in which case we return a Y-Pencil
transpose_shape = (1, 2, 0)
transposed_pdims = pdims
case xla_client.FftType.IFFT:
# IFFT is Z to X to Y so X-Pencil is returned
# In YZ slab case we only need one transposition back to get the X-Pencil
transpose_shape = (2, 0, 1)
transposed_pdims = pdims
case _:
raise TypeError("only complex FFTs are currently supported through pfft.")

expected_slice_shape = (global_shape[0] // pdims[1],
global_shape[1] // pdims[0], global_shape[2])
# Are we operating on the global array?
# This is called when the abstract_eval of the custom partitioning is invoked; see _custom_partitioning_abstract_eval in https://github.com/google/jax/blob/main/jax/experimental/custom_partitioning.py#L223
if x.shape == global_shape:
# only works for cubes
# TODO(wassim) The transpose has to be the same as the slices maybe?
output_shape = x.shape
shape = tuple([global_shape[i] for i in transpose_shape])
output_shape = shape
# Or are we operating on a local slice?
# This is called when JAX calls make_jaxpr(lower_fn) in https://github.com/google/jax/blob/main/jax/experimental/custom_partitioning.py#L142C5-L142C35
elif x.shape == expected_slice_shape:

output_shape = expected_slice_shape

# why do this? there has to be an easier way to check validity of the shape
config = _jaxdecomp.GridConfig()

config.pdims = pdims
# Dimensions are actually in reverse order due to Fortran indexing at the cuDecomp level
config.gdims = global_shape[::-1]
config.halo_comm_backend = jaxdecomp.config.halo_comm_backend
config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend
# Dimensions are actually in reverse order due to Fortran indexing at the cuDecomp level
pencil = _jaxdecomp.get_pencil_info(config, axis)
shape = pencil.shape[::-1]
assert np.prod(shape) == np.prod(
x.shape
), "Only array dimensions divisible by the process mesh size are currently supported. The current configuration leads to local slices of varying sizes between forward and reverse FFT."

# This should never happen
else:
assert False, f"Invalid shape for the input array. Expected either the global shape {global_shape} or the local slice shape {expected_slice_shape}, but got {x.shape}"
output_shape = (global_shape[transpose_shape[0]] // transposed_pdims[1],
global_shape[transpose_shape[1]] // transposed_pdims[0],
global_shape[transpose_shape[2]])

# Sanity check
assert (output_shape is not None)
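To make the shape bookkeeping above concrete: a forward FFT permutes the global shape by (1, 2, 0) (a Z-pencil comes out), and the local slice divides the first two axes by the process grid, matching the expression at the end of this hunk. A standalone sketch with made-up values:

    global_shape, pdims = (4, 8, 16), (2, 2)   # made-up example values
    transpose_shape = (1, 2, 0)                # forward FFT: Z-pencil out
    out_global = tuple(global_shape[i] for i in transpose_shape)   # (8, 16, 4)
    out_local = (out_global[0] // pdims[1],
                 out_global[1] // pdims[0],
                 out_global[2])                                    # (4, 8, 4)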
@@ -119,8 +111,6 @@ def sfft_lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint):
(aval_out,) = ctx.avals_out
dtype = x_aval.dtype
a_type = ir.RankedTensorType(a.type)
n = len(a_type.shape)

# We currently only support complex FFTs through this interface, so let's check the fft type
assert fft_type in (FftType.FFT,
FftType.IFFT), "Only complex FFTs are currently supported"
@@ -129,8 +119,16 @@ def sfft_lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint):
forward = fft_type in (FftType.FFT,)
is_double = np.finfo(dtype).dtype == np.float64

# global_shape and pdims have been provided by the sharding in the custom partitioning definition
# TODO(wassim) maybe they should be None by default, which would discourage users from calling sfft directly
# Get original global shape
match fft_type:
case xla_client.FftType.FFT:
transpose_back_shape = (0, 1, 2)
case xla_client.FftType.IFFT:
transpose_back_shape = (2, 0, 1)
case _:
raise TypeError("only complex FFTs are currently supported through pfft.")
# Make sure to get back the original shape of the X-Pencil
global_shape = tuple([global_shape[i] for i in transpose_back_shape])
# Compute the descriptor for our FFT
config = _jaxdecomp.GridConfig()
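The transpose-back logic above can be checked in isolation: (2, 0, 1) is the inverse of the forward permutation (1, 2, 0), so the IFFT lowering recovers the original X-pencil global shape. A quick standalone check:

    gshape = (4, 8, 16)                                 # original X-pencil shape
    after_fft = tuple(gshape[i] for i in (1, 2, 0))     # Z-pencil order: (8, 16, 4)
    recovered = tuple(after_fft[i] for i in (2, 0, 1))  # IFFT transpose_back
    assert recovered == gshape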

@@ -140,6 +138,8 @@ def sfft_lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint):
config.transpose_comm_backend = jaxdecomp.config.transpose_comm_backend
workspace_size, opaque = _jaxdecomp.build_fft_descriptor(
config, forward, is_double, adjoint)

n = len(a_type.shape)
layout = tuple(range(n - 1, -1, -1))

# We ask XLA to allocate a workspace for this operation.
@@ -161,15 +161,16 @@ def sfft_lowering(ctx, a, *, fft_type, pdims, global_shape, adjoint):
)

# Finally we reshape the array to the expected shape.
return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), result).results
out_type = ir.RankedTensorType.get(aval_out.shape, a_type.element_type)
return hlo.ReshapeOp(out_type, result).results


def _fft_transpose_rule(x, operand, fft_type, pdims, global_shape, adjoint):
assert fft_type in [FftType.FFT, FftType.IFFT]
if fft_type == FftType.FFT:
result = sfft(x, FftType.IFFT, pdims, global_shape, ~adjoint)
result = sfft(x, FftType.IFFT, ~adjoint, pdims, global_shape)
elif fft_type == FftType.IFFT:
result = sfft(x, FftType.FFT, pdims, global_shape, ~adjoint)
result = sfft(x, FftType.FFT, ~adjoint, pdims, global_shape)
else:
raise NotImplementedError

@@ -219,7 +220,6 @@ def partition(fft_type, adjoint, mesh, arg_shapes, result_shape):
input_sharding = arg_shapes[0].sharding

def lower_fn(operand):

# Operand is a local slice and arg_shapes contains the global shape
# No need to re-transpose in the re-lowered function because abstract eval understands sliced input
# and in the original lowering we use aval.out
@@ -230,6 +230,17 @@ def lower_fn(operand):
pdims = (get_axis_size(input_sharding, 1), get_axis_size(input_sharding, 0))

output = sfft(operand, fft_type, adjoint, pdims, global_shape)

# This was supposed to let us avoid an extra transpose in the YZ case,
# but it does not work.
# # In the YZ-slab case the CUDA code transposes only once.
# # We transpose again to give back the Z-Pencil to the user in case of FFT and the X-Pencil in case of IFFT;
# # this transposition is supposed to be compiled out by XLA when doing a gradient (forward followed by backward).
# if get_axis_size(input_sharding, 0) == 1:
# if fft_type == FftType.FFT:
# output = output.transpose((1, 2, 0))
# elif fft_type == FftType.IFFT:
# output = output.transpose((2, 0, 1))
return output

return mesh, lower_fn, \
@@ -257,7 +268,7 @@ def infer_sharding_from_operands(fft_type, adjoint, mesh, arg_shapes,
"""
# only one operand is used in pfft
input_sharding = arg_shapes[0].sharding
return to_named_sharding(input_sharding)
return NamedSharding(mesh, P(*input_sharding.spec))


@partial(custom_partitioning, static_argnums=(1, 2))
14 changes: 2 additions & 12 deletions jaxdecomp/fft.py
@@ -46,18 +46,8 @@ def _do_pfft(


def pfft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
return _do_pfft(
"fft",
xla_client.FftType.FFT,
a,
norm=norm,
)
return _do_pfft("fft", xla_client.FftType.FFT, a, norm=norm)


def pifft3d(a: ArrayLike, norm: Optional[str] = "backward") -> Array:
return _do_pfft(
"ifft",
xla_client.FftType.IFFT,
a,
norm=norm,
)
return _do_pfft("ifft", xla_client.FftType.IFFT, a, norm=norm)
1 change: 1 addition & 0 deletions scripts/autotune.py
@@ -3,6 +3,7 @@
import jaxdecomp

# Initialize jax distributed to instruct jax local process which GPU to use
jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()

1 change: 1 addition & 0 deletions scripts/test_fft3d.py
@@ -9,6 +9,7 @@
import jaxdecomp

# Initialize jax distributed to instruct jax local process which GPU to use
jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()

9 changes: 9 additions & 0 deletions src/fft.cu
@@ -361,6 +361,11 @@ HRESULT FourierExecutor<real_t>::forwardYZ(cudecompHandle_t handle, fftDescripto
// FFT on the second slab
CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_yz, output, output, DIRECTION));

// Extra Y to Z transpose to give back a Z pencil to the user

CHECK_CUDECOMP_EXIT(cudecompTransposeYToZ(handle, m_GridConfig, output, output, work_d,
get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream));

return S_OK;
}

@@ -375,6 +380,10 @@ HRESULT FourierExecutor<real_t>::backwardYZ(cudecompHandle_t handle, fftDescript
CHECK_CUFFT_EXIT(cufftSetWorkArea(m_Plan_c2c_x, work_d));
CHECK_CUFFT_EXIT(cufftSetWorkArea(m_Plan_c2c_yz, work_d));

// Input is a Z-pencil; transpose it back to a Y-pencil
CHECK_CUDECOMP_EXIT(cudecompTransposeZToY(handle, m_GridConfig, input, output, work_d,
get_cudecomp_datatype(complex_t(0)), nullptr, nullptr, stream));

// FFT on the first slab
CHECK_CUFFT_EXIT(cufftXtExec(m_Plan_c2c_yz, input, output, DIRECTION));
// Transpose Y to X
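The net effect of the two new transposes is that the YZ-slab path now matches the pencil path's contract: forward always returns a Z-pencil, backward accepts one. A single-process NumPy analogue of the forward slab pipeline (illustrative only; the real code operates on distributed pencils):

    import numpy as np

    x = np.random.rand(8, 8, 8).astype(np.complex64)
    # YZ slab: FFT over the local Y and Z axes first, then over X.
    k = np.fft.fftn(x, axes=(1, 2))
    k = np.fft.fft(k, axis=0)
    # Axis-wise FFTs commute, so this equals the full 3D FFT; the extra
    # Y-to-Z transpose in fft.cu only changes the distributed layout, not values.
    assert np.allclose(k, np.fft.fftn(x), rtol=1e-3, atol=1e-3)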
5 changes: 5 additions & 0 deletions src/jaxdecomp.cc
@@ -19,6 +19,10 @@ namespace jaxdecomp {
// library is imported, and then implicitly reused in all functions
// cudecompHandle_t handle;

/**
* @brief Initializes the global handle
*/
void init() { jd::GridDescriptorManager::getInstance(); };
/**
* @brief Finalizes the cuDecomp library
*/
@@ -311,6 +315,7 @@ py::dict Registrations() {

PYBIND11_MODULE(_jaxdecomp, m) {
// Utilities
m.def("init", &jd::init);
m.def("finalize", &jd::finalize);
m.def("get_pencil_info", &jd::getPencilInfo);
m.def("get_autotuned_config", &jd::getAutotunedGridConfig);
1 change: 1 addition & 0 deletions tests/test_allgather.py
@@ -13,6 +13,7 @@
from numpy.testing import assert_array_equal

# Initialize jax distributed to instruct jax local process which GPU to use
jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
55 changes: 31 additions & 24 deletions tests/test_fft.py
@@ -35,8 +35,8 @@ def create_spmd_array(global_shape, pdims):
],
key=jax.random.PRNGKey(rank))
# Remap to the global array from the local slice
devices = mesh_utils.create_device_mesh(pdims[::-1])
mesh = Mesh(devices, axis_names=('z', 'y'))
devices = mesh_utils.create_device_mesh(pdims)
mesh = Mesh(devices, axis_names=('y', 'z'))
global_array = multihost_utils.host_local_array_to_global_array(
local_array, mesh, P('z', 'y'))

@@ -46,18 +46,21 @@ def create_spmd_array(global_shape, pdims):
pencil_1 = (size // 2, size // (size // 2)) # 2x2 for V100 and 4x2 for A100
pencil_2 = (size // (size // 2), size // 2) # 2x2 for V100 and 2x4 for A100

decomp = [(size, 1), (1, size), pencil_1, pencil_2]
global_shapes = [(4, 8, 16), (4, 4, 4), (29 * size, 19 * size, 17 * size)
] # Cubes, non-cubes and primes

@pytest.mark.parametrize(
"pdims",
[(1, size),
(size, 1), pencil_1, pencil_2]) # Test with Slab and Pencil decompositions
def test_fft(pdims):

# Cartesian product tests
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_fft(pdims, global_shape):

print("*" * 80)
print(f"Testing with pdims {pdims}")
print(f"Testing with pdims {pdims} and global shape {global_shape}")

global_shape = (29 * size, 19 * size, 17 * size
) # These sizes are prime numbers x size of the pmesh
global_array, mesh = create_spmd_array(global_shape, pdims)

# Perform distributed FFT
@@ -70,31 +73,35 @@ def test_fft(pdims, global_shape):
gathered_array = multihost_utils.process_allgather(global_array, tiled=True)
gathered_karray = multihost_utils.process_allgather(karray, tiled=True)
gathered_rec_array = multihost_utils.process_allgather(rec_array, tiled=True)
jax_karray = jnp.fft.fftn(gathered_array).transpose([1, 2, 0])
jax_rec_array = jnp.fft.ifftn(jax_karray).transpose([2, 0, 1])
jax_karray = jnp.fft.fftn(gathered_array)

# Check reconstructed array
assert_allclose(
gathered_array.real, gathered_rec_array.real, rtol=1e-7, atol=1e-7)
assert_allclose(
gathered_array.imag, gathered_rec_array.imag, rtol=1e-7, atol=1e-7)
# Check the reverse FFT

# Check the forward FFT
transpose_back = [1, 2, 0]
jax_karray_transposed = jax_karray.transpose(transpose_back)
assert_allclose(
gathered_rec_array.real, jax_rec_array.real, rtol=1e-7, atol=1e-7)
gathered_karray.real, jax_karray_transposed.real, rtol=1e-7, atol=1e-7)
assert_allclose(
gathered_rec_array.imag, jax_rec_array.imag, rtol=1e-7, atol=1e-7)
gathered_karray.imag, jax_karray_transposed.imag, rtol=1e-7, atol=1e-7)


# Cartesian product tests
@pytest.mark.parametrize("pdims",
decomp) # Test with Slab and Pencil decompositions
@pytest.mark.parametrize("global_shape",
global_shapes) # Test cubes, non-cubes and primes
def test_grad(pdims, global_shape):

@pytest.mark.parametrize(
"pdims",
[(1, size),
(size, 1), pencil_1, pencil_2]) # Test with Slab and Pencil decompositions
def test_grad(pdims):
transpose_back = [2, 0, 1]

print("*" * 80)
print(f"Testing with pdims {pdims}")
print(f"Testing with pdims {pdims} and global shape {global_shape}")

global_shape = (29 * size, 19 * size, 17 * size
) # These sizes are prime numbers x size of the pmesh
global_array, mesh = create_spmd_array(global_shape, pdims)

print("-" * 40)
@@ -109,7 +116,7 @@ def spmd_grad(arr):
# Perform local FFT
@jax.jit
def local_grad(arr):
y = jnp.fft.fftn(arr).transpose([1, 2, 0])
y = jnp.fft.fftn(arr).transpose(transpose_back)
y = (y * jnp.conjugate(y)).real.sum()
return y

1 change: 1 addition & 0 deletions tests/test_halo.py
@@ -17,6 +17,7 @@
from jaxdecomp import slice_pad, slice_unpad

# Initialize jax distributed to instruct jax local process which GPU to use
jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
1 change: 1 addition & 0 deletions tests/test_padding.py
@@ -17,6 +17,7 @@
from jaxdecomp._src.padding import slice_pad, slice_unpad

# Initialize jax distributed to instruct jax local process which GPU to use
jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()
1 change: 1 addition & 0 deletions tests/test_transpose.py
@@ -5,6 +5,7 @@

import jaxdecomp

jaxdecomp.init()
jax.distributed.initialize()
rank = jax.process_index()
size = jax.process_count()