From fc663e348a448b3ae47d52579750060617820777 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 2 Jun 2020 09:41:44 -0400 Subject: [PATCH] Add support for 64-bit FFTs. (#3290) --- WORKSPACE | 6 +++--- jax/lax/lax_fft.py | 13 ++++++++----- jaxlib/version.py | 2 +- tests/fft_test.py | 5 ++--- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 35c967415cc6..9173c9f6f5c7 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -28,10 +28,10 @@ http_archive( # and update the sha256 with the result. http_archive( name = "org_tensorflow", - sha256 = "642f5a1bc191dfb96b2d7ed1cfb8f2a1515b5169b8de4381c75193cef8404b92", - strip_prefix = "tensorflow-b25fb1fe32094b60f5a53ad5f986ad65a9f05919", + sha256 = "99231c027ad22e1a82866d2e6bc60379d06d0a75793ac09b547282eb5b382d37", + strip_prefix = "tensorflow-37aaafb0c1baa7acd0607748326cc12faf556277", urls = [ - "https://github.com/tensorflow/tensorflow/archive/b25fb1fe32094b60f5a53ad5f986ad65a9f05919.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/37aaafb0c1baa7acd0607748326cc12faf556277.tar.gz", ], ) diff --git a/jax/lax/lax_fft.py b/jax/lax/lax_fft.py index 1beaaa29bb7f..ffec855590bf 100644 --- a/jax/lax/lax_fft.py +++ b/jax/lax/lax_fft.py @@ -23,6 +23,7 @@ from jax.interpreters import xla from jax.util import prod from . import dtypes, lax +from .. import lib from ..lib import xla_client from ..interpreters import ad from ..interpreters import batching @@ -35,16 +36,18 @@ ] def _promote_to_complex(arg): - dtype = onp.result_type(arg, onp.complex64) - # XLA's FFT op only supports C64. - if dtype == onp.complex128: + dtype = dtypes.result_type(arg, onp.complex64) + # XLA's FFT op only supports C64 in jaxlib versions 0.1.47 and earlier. + # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer. + if lib.version <= (0, 1, 47) and dtype == onp.complex128: dtype = onp.complex64 return lax.convert_element_type(arg, dtype) def _promote_to_real(arg): - dtype = onp.result_type(arg, onp.float64) + dtype = dtypes.result_type(arg, onp.float64) # XLA's FFT op only supports F32. - if dtype == onp.float64: + # TODO(phawkins): remove when minimum jaxlib version is 0.1.48 or newer. + if lib.version <= (0, 1, 47) and dtype == onp.float64: dtype = onp.float32 return lax.convert_element_type(arg, dtype) diff --git a/jaxlib/version.py b/jaxlib/version.py index fd2c7cd04050..16d2ca22732e 100644 --- a/jaxlib/version.py +++ b/jaxlib/version.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.1.47" +__version__ = "0.1.48" diff --git a/tests/fft_test.py b/tests/fft_test.py index 7500af891aec..10e831b5f74d 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from absl.testing import parameterized +import jax from jax import lax from jax import numpy as jnp from jax import test_util as jtu @@ -29,9 +30,7 @@ float_dtypes = [np.float32, np.float64] -# TODO(b/144573940): np.complex128 isn't supported by XLA, and the JAX -# implementation casts to complex64. -complex_dtypes = [np.complex64] +complex_dtypes = [np.complex64, np.complex128] inexact_dtypes = float_dtypes + complex_dtypes int_dtypes = [np.int32, np.int64] bool_dtypes = [np.bool_]