diff --git a/tensorflow/core/util/gpu_kernel_helper.h b/tensorflow/core/util/gpu_kernel_helper.h
index fdb84e46420a99..cf98e623fa6226 100644
--- a/tensorflow/core/util/gpu_kernel_helper.h
+++ b/tensorflow/core/util/gpu_kernel_helper.h
@@ -71,6 +71,34 @@ using gpuError_t = hipError_t;
 #endif
 
 namespace tensorflow {
+// Launches a GPU kernel through cudaLaunchKernel in CUDA environment, or
+// hipLaunchKernel in ROCm environment with the given arguments.
+//
+// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
+template <typename... Ts, typename... Args>
+Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
+                       size_t shared_memory_size_bytes, gpuStream_t stream,
+                       Args... arguments) {
+  static_assert(detail::NoneIsReference<Ts...>(),
+                "Kernels with reference arguments have undefined behaviour.");
+#if GOOGLE_CUDA
+  auto func_ptr = absl::bit_cast<const void*>(function);
+  // Cast arguments and forward them as an array of pointers.
+  auto args_tuple = std::tuple<Ts...>(arguments...);
+  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
+  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
+                                 shared_memory_size_bytes, stream);
+  if (result != cudaSuccess) {
+    return errors::Internal(cudaGetErrorString(result));
+  }
+#elif TENSORFLOW_USE_ROCM
+  hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
+                     stream, std::forward<Args>(arguments)...);
+#endif
+  return Status::OK();
+}
+CREATE_CUDA_HOST_FUNCTION_ALIAS(GpuLaunchKernel, CudaLaunchKernel);
+
 __host__ __device__ inline tensorflow::bfloat16 GpuLdg(
     const tensorflow::bfloat16* address) {
   tensorflow::bfloat16 return_value;
diff --git a/tensorflow/core/util/gpu_launch_config.h b/tensorflow/core/util/gpu_launch_config.h
index 3cb61aa7ef546c..4ba2866e8e625c 100644
--- a/tensorflow/core/util/gpu_launch_config.h
+++ b/tensorflow/core/util/gpu_launch_config.h
@@ -400,31 +400,7 @@ constexpr bool NoneIsReference() {
   return NoneTrue<(std::is_reference<Ts>::value)...>::value;
 }
 }  // namespace detail
-
-#if GOOGLE_CUDA
-// Launches a CUDA kernel through cudaLaunchKernel with the given arguments.
-
-// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
-template <typename... Ts, typename... Args>
-Status CudaLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
-                        size_t shared_memory_size_bytes, cudaStream_t stream,
-                        Args... arguments) {
-  static_assert(detail::NoneIsReference<Ts...>(),
-                "Kernels with reference arguments have undefined behaviour.");
-  // Cast arguments and forward them as an array of pointers.
-  auto args_tuple = std::tuple<Ts...>(arguments...);
-  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
-  auto func_ptr = absl::bit_cast<const void*>(function);
-  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
-                                 shared_memory_size_bytes, stream);
-  if (result != cudaSuccess) {
-    return errors::Internal(cudaGetErrorString(result));
-  }
-  return Status::OK();
-}
-#endif  // GOOGLE_CUDA
-
 }  // namespace tensorflow
 
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-#endif  // TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
+#endif  // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_
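
Example usage (illustrative only, not part of this patch): AddOneKernel and LaunchAddOne below are hypothetical names chosen for the sketch; real TensorFlow call sites typically take the stream from an OpKernelContext and the grid/block dimensions from GetGpuLaunchConfig. The sketch assumes a CUDA or ROCm build in which gpu_kernel_helper.h is available.

// Minimal sketch of a caller using the unified GpuLaunchKernel wrapper.
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

// Hypothetical example kernel: adds 1.0f to each of the n input elements
// using a grid-stride loop.
__global__ void AddOneKernel(const float* in, float* out, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = in[i] + 1.0f;
  }
}

// Hypothetical host-side helper: the same call compiles for CUDA and ROCm,
// since GpuLaunchKernel dispatches to cudaLaunchKernel or hipLaunchKernelGGL.
Status LaunchAddOne(const float* in, float* out, int n, gpuStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int blocks = (n + kThreadsPerBlock - 1) / kThreadsPerBlock;
  return GpuLaunchKernel(AddOneKernel, blocks, kThreadsPerBlock,
                         /*shared_memory_size_bytes=*/0, stream, in, out, n);
}

}  // namespace tensorflow

Note that kernel arguments are passed by value; the static_assert in GpuLaunchKernel rejects reference parameters because the CUDA path forwards an array of pointers to each argument.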