diff --git a/pyproject.toml b/pyproject.toml index 71baee3982..fdb771cb49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,7 +67,7 @@ gpu_102 = ["pycuda", "cupy-cuda102", "cufinufft==1.3"] gpu_110 = ["pycuda", "cupy-cuda110", "cufinufft==1.3"] gpu_111 = ["pycuda", "cupy-cuda111", "cufinufft==1.3"] gpu_11x = ["pycuda", "cupy-cuda11x", "cufinufft==1.3"] -gpu_12x = ["pycuda", "cupy-cuda12x", "cufinufft==1.3"] +gpu_12x = ["pycuda", "cupy-cuda12x", "cufinufft==2.2.0"] dev = [ "black", "bumpversion", diff --git a/src/aspire/nufft/cufinufft.py b/src/aspire/nufft/cufinufft.py index dba4cdf4a5..84c757a786 100644 --- a/src/aspire/nufft/cufinufft.py +++ b/src/aspire/nufft/cufinufft.py @@ -28,12 +28,6 @@ def __init__(self, sz, fourier_pts, epsilon=1e-8, ntransforms=1, **kwargs): # Passing "ntransforms" > 1 expects one large higher dimensional array later. self.ntransforms = ntransforms - # Workaround cufinufft A100 singles issue - # ASPIRE-Python/703 - # Cast to doubles. - self._original_dtype = fourier_pts.dtype - fourier_pts = fourier_pts.astype(np.float64, copy=False) - # Basic dtype passthough. dtype = fourier_pts.dtype if dtype == np.float64 or dtype == np.complex128: @@ -102,12 +96,8 @@ def transform(self, signal): `(ntransforms, num_pts)`. """ - # Check we're not forcing a dtype workaround for ASPIRE-Python/703, - # then check if we have a dtype mismatch. - # This avoids false positive complaint for the workaround. - if (self._original_dtype == self.dtype) and not ( - signal.dtype == self.dtype or signal.dtype == self.complex_dtype - ): + # Check dtype mismatch. + if not (signal.dtype == self.dtype or signal.dtype == self.complex_dtype): logger.warning( "Incorrect dtypes passed to (a)nufft." " In the future this will be an error." @@ -143,8 +133,6 @@ def transform(self, signal): self._transform_plan.execute(result_gpu, signal_gpu) result = result_gpu.get() - # ASPIRE-Python/703 - result = result.astype(complex_type(self._original_dtype), copy=False) return result