6 changes: 3 additions & 3 deletions apex/contrib/csrc/groupbn/batch_norm.h
@@ -36,7 +36,7 @@
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"

+ #include <ATen/cuda/CUDAContext.h>

#define VERBOSE_DEFAULT false

@@ -626,7 +626,7 @@ class NhwcBatchNorm {
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
- int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+ int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
@@ -635,7 +635,7 @@ class NhwcBatchNorm {
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
- int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+ int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
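Note: C10_WARP_SIZE is fixed at compile time (32 in CUDA builds, 64 in ROCm builds), while at::cuda::warp_size() from ATen/cuda/CUDAContext.h queries the warp size of the active device at runtime, so one binary sizes these shared-memory reductions correctly on either lane width. The same substitution is applied to batch_norm_add_relu.h below. A minimal host-side sketch of the occupancy arithmetic; the tile constants and smem footprint here are illustrative stand-ins, not the kernel's real configuration:

    #include <algorithm>
    #include <ATen/cuda/CUDAContext.h>

    constexpr int THREADS_PER_CTA   = 512;    // assumed CTA shape
    constexpr int THREADS_PER_PIXEL = 16;     // assumed
    constexpr int ELEMENTS_PER_LDG  = 4;      // assumed
    constexpr int SMEM_SIZE_FWD     = 20480;  // hypothetical static smem bytes

    int smem_driven_occupancy_sketch(int max_smem_per_sm, int max_cta_per_sm) {
      // warp_size() reads the warpSize property of the current device, so the
      // warps-per-CTA count adapts to 32-lane (CUDA) or 64-lane (ROCm) hardware.
      const int warps_per_cta   = THREADS_PER_CTA / at::cuda::warp_size();
      const int reduction_bytes = THREADS_PER_PIXEL * warps_per_cta *
                                  ELEMENTS_PER_LDG * static_cast<int>(sizeof(float));
      return std::min(max_cta_per_sm, max_smem_per_sm / (SMEM_SIZE_FWD + reduction_bytes));
    }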
5 changes: 3 additions & 2 deletions apex/contrib/csrc/groupbn/batch_norm_add_relu.h
@@ -36,6 +36,7 @@
#include "nhwc_batch_norm_kernel.h"
#include "cuda_utils.h"
#include "c10/macros/Macros.h"
+ #include <ATen/cuda/CUDAContext.h>

#ifdef USE_ROCM
using bitmask_t = uint64_t;
@@ -530,7 +531,7 @@ class NhwcBatchNormAddRelu {
// Calculate the expected fwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_fwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
- int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+ int fwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
int fwd_smem_bytes = SMEM_SIZE_FWD + fwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / fwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
@@ -539,7 +540,7 @@ class NhwcBatchNormAddRelu {
// Calculate the expected bwd kernel occupancy, as dictated by shared memory usage.
static int smem_driven_bwd_occupancy(int device_id, const int max_cta_per_sm) {
using namespace at::cuda::utils;
- int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/C10_WARP_SIZE)*ELEMENTS_PER_LDG*sizeof(float);
+ int bwd_reduction_bytes = THREADS_PER_PIXEL*(THREADS_PER_CTA/at::cuda::warp_size())*ELEMENTS_PER_LDG*sizeof(float);
int bwd_smem_bytes = SMEM_SIZE_BWD + bwd_reduction_bytes;
int occupancy = MaxSharedMemoryPerMultiprocessor(device_id) / bwd_smem_bytes;
return std::min(max_cta_per_sm, occupancy);
56 changes: 19 additions & 37 deletions apex/contrib/csrc/groupbn/nhwc_batch_norm_kernel.h
@@ -546,14 +546,8 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
void* params_my_data, void** params_pair_datas, int off,
const int magic,
const int sync_iters) {
- // The size of a warp.
- #ifdef USE_ROCM
- const int THREADS_PER_WARP = 64;
- #else
- const int THREADS_PER_WARP = 32;
- #endif
// The number of warps in a CTA.
- const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+ const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 16;
// The number of elements per ldg.
@@ -564,13 +558,13 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,
const int MAX_BLOCK_Y = 256;
const int MAX_OFFSET = REDUCE_OPS*MAX_BLOCK_Y;
// The warp decomposition.
- const int warp_id = threadIdx.x / THREADS_PER_WARP;
- const int lane_id = threadIdx.x % THREADS_PER_WARP;
+ const int warp_id = threadIdx.x / C10_WARP_SIZE;
+ const int lane_id = threadIdx.x % C10_WARP_SIZE;
// total size of data per sync iter
const int data_total = MAX_OFFSET*THREADS_PER_PIXEL*ELEMENTS_PER_LDG*2;

#ifdef USE_ROCM
- for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
+ for (int offset = THREADS_PER_PIXEL; offset <= C10_WARP_SIZE >> 1; offset <<= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
@@ -598,16 +592,16 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,

#pragma unroll
for (int offset = 1;
- offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+ offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
- read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+ read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE);
// Compute the updated sum.
add(x, y);
}

#ifdef USE_ROCM
- for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
+ for (int offset = C10_WARP_SIZE >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) { for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
}
@@ -681,21 +675,15 @@ DEVICE_FUNCTION void parallel_sums_16x2(float *smem, float (&x)[4], int nhw,

template< int THREADS_PER_CTA >
DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {
- // The size of a warp.
- #ifdef USE_ROCM
- const int THREADS_PER_WARP = 64;
- #else
- const int THREADS_PER_WARP = 32;
- #endif
// The number of warps in a CTA.
- const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+ const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE;
// The number of threads per pixel.
const int THREADS_PER_PIXEL = 8;
// The number of elements per ldg.
const int ELEMENTS_PER_LDG = 4;
// The warp decomposition.
- const int warp_id = threadIdx.x / THREADS_PER_WARP;
- const int lane_id = threadIdx.x % THREADS_PER_WARP;
+ const int warp_id = threadIdx.x / C10_WARP_SIZE;
+ const int lane_id = threadIdx.x % C10_WARP_SIZE;

#pragma unroll
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
@@ -718,10 +706,10 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {

#pragma unroll
for (int offset = 1;
- offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+ offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
- read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+ read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE);
// Compute the updated sum.
add(x, y);
}
@@ -745,20 +733,14 @@ DEVICE_FUNCTION void parallel_sums_8x4(float *smem, float (&x)[4], int nhw) {

template< int THREADS_PER_CTA, int THREADS_PER_PIXEL, int ELEMENTS_PER_LDG >
DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {
- // The size of a warp.
- #ifdef USE_ROCM
- const int THREADS_PER_WARP = 64;
- #else
- const int THREADS_PER_WARP = 32;
- #endif
- const int WARPS_PER_CTA = THREADS_PER_CTA / THREADS_PER_WARP;
+ const int WARPS_PER_CTA = THREADS_PER_CTA / C10_WARP_SIZE;
// The warp decomposition.
- const int warp_id = threadIdx.x / THREADS_PER_WARP;
- const int lane_id = threadIdx.x % THREADS_PER_WARP;
+ const int warp_id = threadIdx.x / C10_WARP_SIZE;
+ const int lane_id = threadIdx.x % C10_WARP_SIZE;
// total size of data per sync iter

#ifdef USE_ROCM
- for (int offset = THREADS_PER_PIXEL; offset <= THREADS_PER_WARP >> 1; offset <<= 1) {
+ for (int offset = THREADS_PER_PIXEL; offset <= C10_WARP_SIZE >> 1; offset <<= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
@@ -786,16 +768,16 @@ DEVICE_FUNCTION void parallel_sums(float *smem, float (&x)[ELEMENTS_PER_LDG], int nhw) {

#pragma unroll
for (int offset = 1;
- offset < WARPS_PER_CTA/(THREADS_PER_WARP / THREADS_PER_PIXEL); ++offset) {
+ offset < WARPS_PER_CTA/(C10_WARP_SIZE / THREADS_PER_PIXEL); ++offset) {
float y[ELEMENTS_PER_LDG];
// Read the mean and variance from the other pixel.
- read_from_smem(y, smem, threadIdx.x + offset*THREADS_PER_WARP);
+ read_from_smem(y, smem, threadIdx.x + offset*C10_WARP_SIZE);
// Compute the updated sum.
add(x, y);
}

#ifdef USE_ROCM
- for (int offset = THREADS_PER_WARP >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
+ for (int offset = C10_WARP_SIZE >> 1; offset >= THREADS_PER_PIXEL; offset >>= 1) {
for (int i = 0; i < ELEMENTS_PER_LDG; ++i) {
x[i] += shfl_sync(x[i], offset + lane_id);
}
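Note: unlike the host-side occupancy helpers above, these parallel-sum routines run in device code, where the host-only query at::cuda::warp_size() cannot be called. They switch to the C10_WARP_SIZE macro instead, which PyTorch already defines per compilation target (32 for CUDA, 64 for ROCm device builds), so the hand-rolled #ifdef USE_ROCM / THREADS_PER_WARP blocks become redundant and are deleted. A CUDA-flavored sketch of the warp-decomposition idiom these kernels use; the kernel and its shape are illustrative only, not part of this file:

    #include <c10/macros/Macros.h>

    __global__ void warp_sum_sketch(const float* in, float* out, int n) {
      // C10_WARP_SIZE is a compile-time constant, so it is usable in device
      // code for the same warp_id/lane_id decomposition as parallel_sums.
      const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
      const int lane_id = threadIdx.x % C10_WARP_SIZE;
      float v = (tid < n) ? in[tid] : 0.f;
      // Butterfly reduction across one warp, mirroring the shfl loops above
      // (CUDA-style full mask; a ROCm build would go through __shfl_xor).
      for (int offset = C10_WARP_SIZE / 2; offset > 0; offset >>= 1)
        v += __shfl_xor_sync(0xffffffff, v, offset);
      if (lane_id == 0) atomicAdd(out, v);
    }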
27 changes: 15 additions & 12 deletions apex/contrib/csrc/multihead_attn/softmax.cuh
@@ -17,6 +17,7 @@
#include <stdint.h>
#include <cuda_fp16.h>
#include <cmath>
+ #include <ATen/cuda/CUDAContext.h>

#ifdef USE_ROCM
#define APEX_WARP_SHFL_XOR(mask, value, offset, width) __shfl_xor(value, offset, width)
@@ -235,7 +236,7 @@ bool warp_softmax_kernel(int log2_elements, int &warp_size,
softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -654,7 +655,7 @@ bool warp_additive_masked_softmax_dropout_kernel(
&kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -948,7 +949,7 @@ bool warp_additive_masked_softmax_kernel(
additive_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -1240,7 +1241,7 @@ bool warp_masked_softmax_kernel(
masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -1488,7 +1489,7 @@ bool warp_time_masked_softmax_kernel(
time_masked_softmax_forward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -1741,7 +1742,7 @@ void dispatch_masked_scale_softmax_backward_masked_out(
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
- (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
@@ -1855,7 +1856,8 @@ void dispatch_masked_scale_softmax_backward_masked_out_stream(
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
- (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
+
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -2254,7 +2256,7 @@ bool masked_scale_softmax_warp_backward_recompute_kernel(
is_log_softmax> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -2392,7 +2394,8 @@ void dispatch_masked_scale_softmax_backward_stream(
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
- (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();
+
// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -2593,7 +2596,7 @@ void dispatch_softmax_backward_fused_native(
// This value must match the WARP_SIZE constexpr value computed inside
// softmax_warp_backward.
int warp_size =
- (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+ (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// This value must match the WARP_BATCH constexpr value computed inside
// softmax_warp_backward.
@@ -2805,7 +2808,7 @@ bool warp_softmax_backward_kernel(
softmax_backward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
@@ -3048,7 +3051,7 @@ bool warp_masked_softmax_backward_kernel(
masked_softmax_backward_func<input_t, output_t> &kernel) {
// determine size of a warp
const int next_power_of_two = 1 << log2_elements;
- warp_size = (next_power_of_two < 32) ? next_power_of_two : 32;
+ warp_size = (next_power_of_two < at::cuda::warp_size()) ? next_power_of_two : at::cuda::warp_size();

// determine how many batches a warp should process.
batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
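Note: every dispatcher in this file sizes its logical warp the same way: a row of 2^log2_elements values needs no more cooperating lanes than it has elements (rounded up to a power of two), and never more than one hardware warp, which is now queried at runtime rather than hardcoded as 32 or C10_WARP_SIZE. A minimal sketch of that sizing logic under hypothetical names; the real dispatchers go on to select and launch a templated kernel:

    #include <ATen/cuda/CUDAContext.h>

    void warp_softmax_sizing_sketch(int log2_elements, int &warp_size,
                                    int &batches_per_warp) {
      const int next_power_of_two = 1 << log2_elements;
      const int hw_warp = at::cuda::warp_size();  // 32 on CUDA, 64 on ROCm
      // Clamp the cooperating group to the row length, capped at one hardware warp.
      warp_size = (next_power_of_two < hw_warp) ? next_power_of_two : hw_warp;
      // Short rows (<= 128 elements) are cheap enough to batch two per warp.
      batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
    }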
14 changes: 7 additions & 7 deletions apex/contrib/csrc/transducer/transducer_joint_kernel.cu
@@ -729,8 +729,8 @@ std::vector<torch::Tensor> transducer_joint_cuda_forward(

TORCH_CHECK(opt == 0 or opt == 1, "Got an invalid optimization level ", opt);
// Simple heuristics
- const int numThread = std::min(128, (static_cast<int>(hiddenSize)+C10_WARP_SIZE-1)
- / C10_WARP_SIZE * C10_WARP_SIZE);
+ const int numThread = std::min(128, (static_cast<int>(hiddenSize)+at::cuda::warp_size()-1)
+ / at::cuda::warp_size() * at::cuda::warp_size());

if (opt == 0){
// vanilla kernel
@@ -862,7 +862,7 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward(
const int hiddenSize = grad.size(-1);

const auto deviceProperties = at::cuda::getCurrentDeviceProperties();
- const int maxNumWarp = deviceProperties->maxThreadsPerBlock / C10_WARP_SIZE;
+ const int maxNumWarp = deviceProperties->maxThreadsPerBlock / at::cuda::warp_size();

torch::Tensor fGrad = torch::empty({batchSize, maxFLen, hiddenSize}, tensorOpt);
torch::Tensor gGrad = torch::empty({batchSize, maxGLen, hiddenSize}, tensorOpt);
@@ -880,8 +880,8 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward(

// Need smem for transposing the partial sum. The partial sum is in a matrix of the shape
// numWarp x warpSize
- const int smemSize = numWarp * C10_WARP_SIZE;
- const dim3 threads(C10_WARP_SIZE, numWarp, 1);
+ const int smemSize = numWarp * at::cuda::warp_size();
+ const dim3 threads(at::cuda::warp_size(), numWarp, 1);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(dtype, "transducer_joint_cuda_backward_kernel", ([&] {
auto gradPtr = grad.data_ptr<scalar_t>();
@@ -905,7 +905,7 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward(
if (vectFactor > 1 and hiddenSize%vectFactor == 0 and memAlign){
// If vectorization helps and the alignment requirement is met, use the vectorized
// kernel. For simplicity, hiddenSize needs to be a multiple vecFactor.
- const dim3 blocks( (hiddenSize+C10_WARP_SIZE*vectFactor-1)/(C10_WARP_SIZE*vectFactor),
+ const dim3 blocks( (hiddenSize+at::cuda::warp_size()*vectFactor-1)/(at::cuda::warp_size()*vectFactor),
maxFLen+maxGLen,
batchSize);
if (masked){
@@ -944,7 +944,7 @@ std::vector<torch::Tensor> transducer_joint_cuda_backward(
}
}
else{
- const dim3 blocks((hiddenSize+C10_WARP_SIZE-1)/C10_WARP_SIZE,
+ const dim3 blocks((hiddenSize+at::cuda::warp_size()-1)/at::cuda::warp_size(),
maxFLen + maxGLen, batchSize);
if (masked){
transducer_joint_combined_backward<scalar_t, acc_t, OffsetCalBwd, true>
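Note: the forward-pass heuristic rounds hiddenSize up to a whole number of warps and caps the block at 128 threads, so with the runtime query the same expression produces different launch shapes on 32-lane and 64-lane devices. A small worked sketch (the function name is hypothetical):

    #include <algorithm>
    #include <ATen/cuda/CUDAContext.h>

    int num_threads_sketch(int hidden_size) {
      const int ws = at::cuda::warp_size();
      // Round hidden_size up to a multiple of the warp size, then cap at 128.
      return std::min(128, (hidden_size + ws - 1) / ws * ws);
    }
    // e.g. hidden_size = 70 -> 96 threads on a 32-lane GPU, 128 on a 64-lane GPU.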