Skip to content

Commit

Permalink
Creating GpuLaunchKernel
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryyin committed May 7, 2019
1 parent 4660f6d commit 0b8e0ae
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 25 deletions.
29 changes: 29 additions & 0 deletions tensorflow/core/util/gpu_kernel_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,35 @@ using gpuError_t = hipError_t;
#endif

namespace tensorflow {
// Launches a GPU kernel through cudaLaunchKernel in Cuda environment, or
// hipLaunchKernelGGL in ROCm environment with the given arguments.
//
// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
// Returns an Internal error if the CUDA launch fails; the ROCm launch macro
// does not report errors here (query hipGetLastError separately if needed).
template <typename... Ts, typename... Args>
Status GpuLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
                       size_t shared_memory_size_bytes, gpuStream_t stream,
                       Args... arguments) {
  static_assert(detail::NoneIsReference<Ts...>(),
                "Kernels with reference arguments have undefined behaviour.");
#if GOOGLE_CUDA
  // cudaLaunchKernel takes the kernel entry point as an untyped pointer.
  auto func_ptr = absl::bit_cast<const void*>(function);
  // Cast arguments and forward them as an array of pointers.
  auto args_tuple = std::tuple<Ts...>(arguments...);
  auto arg_ptrs = detail::GetArrayOfElementPointers(&args_tuple);
  auto result = cudaLaunchKernel(func_ptr, grid_dim, block_dim, arg_ptrs.data(),
                                 shared_memory_size_bytes, stream);
  if (result != cudaSuccess) {
    return errors::Internal(cudaGetErrorString(result));
  }
#elif TENSORFLOW_USE_ROCM
  // hipLaunchKernelGGL is a macro that requires the typed kernel symbol, not
  // a type-erased pointer, so pass 'function' directly.
  hipLaunchKernelGGL(function, grid_dim, block_dim, shared_memory_size_bytes,
                     stream, std::forward<Args>(arguments)...);
#endif
  return Status::OK();
}
#endif
CREATE_CUDA_HOST_FUNCTION_ALIAS(GpuLaunchKernel, CudaLaunchKernel);

__host__ __device__ inline tensorflow::bfloat16 GpuLdg(
const tensorflow::bfloat16* address) {
tensorflow::bfloat16 return_value;
Expand Down
26 changes: 1 addition & 25 deletions tensorflow/core/util/gpu_launch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -400,31 +400,7 @@ constexpr bool NoneIsReference() {
return NoneTrue<(std::is_reference<Ts>::value)...>::value;
}
} // namespace detail

#if GOOGLE_CUDA
// Launches a CUDA kernel via cudaLaunchKernel, forwarding 'arguments' to the
// kernel's parameters.
//
// The kernel parameters 'Ts' must be constructible from the arguments 'Args'.
// Returns an Internal error carrying the CUDA error string on failure.
template <typename... Ts, typename... Args>
Status CudaLaunchKernel(void (*function)(Ts...), dim3 grid_dim, dim3 block_dim,
                        size_t shared_memory_size_bytes, cudaStream_t stream,
                        Args... arguments) {
  static_assert(detail::NoneIsReference<Ts...>(),
                "Kernels with reference arguments have undefined behaviour.");
  // cudaLaunchKernel wants a type-erased pointer to the kernel entry point.
  auto kernel_ptr = absl::bit_cast<const void*>(function);
  // Materialize the arguments as the kernel's parameter types, then hand the
  // runtime an array of pointers to those copies.
  std::tuple<Ts...> packed_args(arguments...);
  auto packed_ptrs = detail::GetArrayOfElementPointers(&packed_args);
  auto launch_status =
      cudaLaunchKernel(kernel_ptr, grid_dim, block_dim, packed_ptrs.data(),
                       shared_memory_size_bytes, stream);
  if (launch_status != cudaSuccess) {
    return errors::Internal(cudaGetErrorString(launch_status));
  }
  return Status::OK();
}
#endif // GOOGLE_CUDA

} // namespace tensorflow

#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#endif // TENSORFLOW_CORE_UTIL_GPU_KERNEL_HELPER_H_
#endif // TENSORFLOW_CORE_UTIL_GPU_LAUNCH_CONFIG_H_

0 comments on commit 0b8e0ae

Please sign in to comment.