Skip to content

Commit

Permalink
Merge pull request tensorflow#54 from ROCmSoftwarePlatform/deven_unit…
Browse files Browse the repository at this point in the history
…_test_fixes_180626

Special casing GpuAtomicMin / GpuAtomicMax for ROCm
  • Loading branch information
whchung committed Jun 27, 2018
2 parents e703ec9 + b2c2cd3 commit 640398b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 2 deletions.
58 changes: 58 additions & 0 deletions tensorflow/core/util/gpu_device_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -648,7 +648,34 @@ template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMax(T* ptr, U value) {
return atomicMax(ptr, value);
}

#if TENSORFLOW_USE_ROCM

/*
 * CUDA runtime headers provide the following overloads:
 * __device__ int max(int, int)
 * __device__ float max(float, float)
 * __device__ double max(double, double)
 *
 * and many others, whereas HIP runtime headers only provide the "int" version.
 *
 * Therefore the ROCm build must be special-cased to call the correct
 * underlying routines for the float and double types.
 *
 */

// Atomically stores max(*ptr, value) into *ptr via a CAS loop and returns the
// previous value. HIP headers lack a float overload of max(), so fmaxf is
// called explicitly here.
__device__ inline float GpuAtomicMax(float* ptr, float value) {
  auto take_max = [value](float current) { return fmaxf(current, value); };
  return detail::GpuAtomicCasHelper(ptr, take_max);
}

// Atomically stores max(*ptr, value) into *ptr via a CAS loop and returns the
// previous value. HIP headers lack a double overload of max(), so fmax is
// called explicitly here.
__device__ inline double GpuAtomicMax(double* ptr, double value) {
  auto take_max = [value](double current) { return fmax(current, value); };
  return detail::GpuAtomicCasHelper(ptr, take_max);
}

#else

__device__ inline float GpuAtomicMax(float* ptr, float value) {
return detail::GpuAtomicCasHelper(
ptr, [value](float a) { return max(a, value); });
Expand All @@ -659,6 +686,8 @@ __device__ inline double GpuAtomicMax(double* ptr, double value) {
ptr, [value](double a) { return max(a, value); });
}

#endif

__device__ inline Eigen::half GpuAtomicMax(Eigen::half* ptr,
Eigen::half value) {
return detail::GpuAtomicCasHelper(
Expand All @@ -678,7 +707,34 @@ template <typename T, typename U>
__device__ detail::ToTypeIfConvertible<U, T> GpuAtomicMin(T* ptr, U value) {
return atomicMin(ptr, value);
}

#if TENSORFLOW_USE_ROCM

/*
 * CUDA runtime headers provide the following overloads:
 * __device__ int min(int, int)
 * __device__ float min(float, float)
 * __device__ double min(double, double)
 *
 * and many others, whereas HIP runtime headers only provide the "int" version.
 *
 * Therefore the ROCm build must be special-cased to call the correct
 * underlying routines for the float and double types.
 *
 */

// Atomically stores min(*ptr, value) into *ptr via a CAS loop and returns the
// previous value. HIP headers lack a float overload of min(), so fminf is
// called explicitly here.
__device__ inline float GpuAtomicMin(float* ptr, float value) {
  auto take_min = [value](float current) { return fminf(current, value); };
  return detail::GpuAtomicCasHelper(ptr, take_min);
}

// Atomically stores min(*ptr, value) into *ptr via a CAS loop and returns the
// previous value. HIP headers lack a double overload of min(), so fmin is
// called explicitly here.
__device__ inline double GpuAtomicMin(double* ptr, double value) {
  auto take_min = [value](double current) { return fmin(current, value); };
  return detail::GpuAtomicCasHelper(ptr, take_min);
}

#else

__device__ inline float GpuAtomicMin(float* ptr, float value) {
return detail::GpuAtomicCasHelper(
ptr, [value](float a) { return min(a, value); });
Expand All @@ -689,6 +745,8 @@ __device__ inline double GpuAtomicMin(double* ptr, double value) {
ptr, [value](double a) { return min(a, value); });
}

#endif

__device__ inline Eigen::half GpuAtomicMin(Eigen::half* ptr,
Eigen::half value) {
return detail::GpuAtomicCasHelper(
Expand Down
2 changes: 0 additions & 2 deletions tensorflow/tools/ci_build/linux/rocm/run_py3_core.sh
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ bazel test --test_sharding_strategy=disabled --config=rocm --test_tag_filters=-n
-//tensorflow/python/kernel_tests:pool_test \
-//tensorflow/python/kernel_tests:pooling_ops_3d_test \
-//tensorflow/python/kernel_tests:pooling_ops_test \
-//tensorflow/python/kernel_tests:reduction_ops_test \
-//tensorflow/python/kernel_tests:scatter_ops_test \
-//tensorflow/python/profiler/internal:run_metadata_test \
-//tensorflow/python/profiler:profile_context_test \
-//tensorflow/python/profiler:profiler_test \
Expand Down

0 comments on commit 640398b

Please sign in to comment.