Skip to content

Commit

Permalink
Add block size configurations 512 and 1024 to the softmax and softmax-grad block shared-memory dispatch
Browse files: browse the repository at this point in the history
  • Loading branch information
guo-ran committed May 31, 2021
1 parent 820c909 commit 5f34b13
Showing 1 changed file with 48 additions and 4 deletions.
52 changes: 48 additions & 4 deletions oneflow/core/cuda/softmax.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,8 @@ inline bool TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, FETCH
const int64_t rows, const int64_t cols) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(typename GetComputeType<T>::type);
int max_active_blocks_conf_1;
int max_active_blocks_conf_2;
Expand All @@ -491,8 +493,28 @@ inline bool TryDispatchSoftmaxBlockSMemImplBlockSize(cudaStream_t stream, FETCH
SoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_2>, block_size_conf_2,
smem));
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
LaunchSoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_2>(stream, fetch, store,
smem, rows, cols);
int max_active_blocks_conf_3;
OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
SoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_3>, block_size_conf_3,
smem));
if (max_active_blocks_conf_3 == max_active_blocks_conf_2) {
int max_active_blocks_conf_4;
OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
SoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_4>, block_size_conf_4,
smem));
if (max_active_blocks_conf_4 == max_active_blocks_conf_3) {
LaunchSoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_4>(
stream, fetch, store, smem, rows, cols);
} else {
LaunchSoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_3>(
stream, fetch, store, smem, rows, cols);
}
} else {
LaunchSoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_2>(
stream, fetch, store, smem, rows, cols);
}
} else {
LaunchSoftmaxBlockSMemImpl<FETCH, STORE, T, pack_size, block_size_conf_1>(stream, fetch, store,
smem, rows, cols);
Expand Down Expand Up @@ -908,6 +930,8 @@ inline bool TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, FE
const int64_t rows, const int64_t cols) {
constexpr int block_size_conf_1 = 128;
constexpr int block_size_conf_2 = 256;
constexpr int block_size_conf_3 = 512;
constexpr int block_size_conf_4 = 1024;
const size_t smem = cols * sizeof(typename GetComputeType<T>::type) * 2;
int max_active_blocks_conf_1;
int max_active_blocks_conf_2;
Expand All @@ -921,8 +945,28 @@ inline bool TryDispatchSoftmaxGradBlockSMemImplBlockSize(cudaStream_t stream, FE
SoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_2>,
block_size_conf_2, smem));
if (max_active_blocks_conf_2 == max_active_blocks_conf_1) {
LaunchSoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_2>(
stream, fetch_y, fetch_dy, store, smem, rows, cols);
int max_active_blocks_conf_3;
OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_3,
SoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_3>,
block_size_conf_3, smem));
if (max_active_blocks_conf_3 == max_active_blocks_conf_2) {
int max_active_blocks_conf_4;
OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks_conf_4,
SoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_4>,
block_size_conf_4, smem));
if (max_active_blocks_conf_4 == max_active_blocks_conf_3) {
LaunchSoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_4>(
stream, fetch_y, fetch_dy, store, smem, rows, cols);
} else {
LaunchSoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_3>(
stream, fetch_y, fetch_dy, store, smem, rows, cols);
}
} else {
LaunchSoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_2>(
stream, fetch_y, fetch_dy, store, smem, rows, cols);
}
} else {
LaunchSoftmaxGradBlockSMemImpl<FETCH_Y, FETCH_DY, STORE, T, pack_size, block_size_conf_1>(
stream, fetch_y, fetch_dy, store, smem, rows, cols);
Expand Down

0 comments on commit 5f34b13

Please sign in to comment.