[Optimization] Warp level reduction support for CUDA (#5498)
- Added warp-level reduction support.

- Upgraded the shfl intrinsics to their *_sync versions.

- This is the building block for scheduling softmax-like operations.

Signed-off-by: Wei Pan <weip@nvidia.com>
wpan11nv committed May 9, 2020
1 parent ae7e0a1 commit 64c6795
Showing 9 changed files with 462 additions and 67 deletions.
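For context, a minimal sketch (not code from this commit; the device function name is invented) of the warp-level reduction pattern that these intrinsics enable on CUDA:

// Tree reduction over synchronizing shuffles: each step halves the
// number of live partial sums, so after log2(32) = 5 steps lane 0
// holds the sum of all 32 lanes in the warp.
__device__ float warp_reduce_sum(float val) {
  unsigned mask = __activemask();
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(mask, val, offset, 32);
  }
  return val;  // fully reduced value is valid in lane 0
}

Softmax-style schedules chain this with a warp-level max reduction, avoiding shared-memory round trips for the inner reduction axis.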
27 changes: 24 additions & 3 deletions include/tvm/tir/expr.h
@@ -1234,22 +1234,43 @@ constexpr const char *tvm_call_trace_packed_lowered =
* }
*/
constexpr const char* tvm_storage_sync = "tvm_storage_sync";

/*!
* \brief See pseudo code
*
- * Type tvm_warp_shuffle(Type value, warp_id, width, warp_size) {
- *   return (value passed in by warp indicated by warp_id);
+ * Type tvm_warp_shuffle(mask, Type value, warp_id, width, warp_size) {
+ *   return (value passed in by warp indicated by this_warp_id);
* }
*
+ * Type tvm_warp_shuffle_up(mask, Type value, offset, width, warp_size) {
+ *   return (value passed in by warp indicated by this_warp_id - offset);
+ * }
+ *
+ * Type tvm_warp_shuffle_down(mask, Type value, offset, width, warp_size) {
+ *   return (value passed in by warp indicated by this_warp_id + offset);
+ * }
+ *
+ * unsigned tvm_warp_activemask() {
+ *   return (32-bit mask of currently active threads in the calling warp);
+ * }
+ *
* Parameter warp_id indicates the source thread ID in a warp.
*
+ * Parameter offset indicates the relative distance to this_warp_id.
+ *
* Parameter width indicates the number of threads involved in one
- * shuffle. See CUDA document for __shfl.
+ * shuffle. See CUDA document for __shfl_sync, __shfl_up_sync,
+ * __shfl_down_sync and __activemask.
*
* Parameter warp_size is the size of a warp, which helps a backend
* to determine whether the width parameter is legal.
*
*/
constexpr const char* tvm_warp_shuffle = "tvm_warp_shuffle";
+constexpr const char* tvm_warp_shuffle_up = "tvm_warp_shuffle_up";
+constexpr const char* tvm_warp_shuffle_down = "tvm_warp_shuffle_down";
+constexpr const char* tvm_warp_activemask = "tvm_warp_activemask";

/*!
* \brief Initialize the global barrier.
* Call this at the beginning of kernels that need the global barrier.
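To make the shuffle-up semantics above concrete, here is a hedged sketch (not part of this header; the helper name is invented) of an inclusive warp prefix sum written against the CUDA intrinsic that tvm_warp_shuffle_up lowers to:

// Kogge-Stone scan: at each step, lane i adds the value fetched from
// lane (i - offset), which is the shuffle-up contract documented above.
// Lanes with i < offset receive their own value back and must not add it.
__device__ float warp_inclusive_scan_sum(float val) {
  unsigned mask = __activemask();
  int lane = threadIdx.x & 31;  // this_warp_id in the pseudo code
  for (int offset = 1; offset < 32; offset <<= 1) {
    float up = __shfl_up_sync(mask, val, offset, 32);
    if (lane >= offset) val += up;
  }
  return val;
}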
19 changes: 18 additions & 1 deletion src/target/source/codegen_cuda.cc
@@ -64,6 +64,10 @@ std::string CodeGenCUDA::Finish() {
decl_stream << _cuda_half_util;
}

+ if (enable_warp_shuffle_) {
+   decl_stream << _cuda_warp_intrinsic_util;
+ }

if (enable_int8_) {
decl_stream << "#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)\n";
decl_stream << "#include <sm_61_intrinsics.h>\n";
@@ -269,6 +273,11 @@ void CodeGenCUDA::PrintVecBinaryOp(

void CodeGenCUDA::PrintVecElemLoad(
const std::string& vec, DataType t, int i, std::ostream& os) { // NOLINT(*)
+ if (t.is_scalar()) {
+   os << vec;
+   return;
+ }

static const char access[] = {'x', 'y', 'z', 'w'};
CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
if ((t.is_int()) && t.bits() == 8) {
@@ -395,7 +404,15 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) {
os << sret;
}

-void CodeGenCUDA::VisitExpr_(const CallNode *op, std::ostream& os) {
+void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) {
+ // This is only for backward compatibility with __shfl_{up/down}.
+ // A macro will be used to replace the *_sync calls with the legacy ones.
+ if (op->is_intrinsic("__shfl_sync") ||
+     op->is_intrinsic("__shfl_up_sync") ||
+     op->is_intrinsic("__shfl_down_sync")) {
+   enable_warp_shuffle_ = true;
+ }

if (op->is_intrinsic(intrinsic::tvm_fill_fragment)) {
need_mma_h_ = true;
CHECK_EQ(op->args.size(), 6U);
2 changes: 2 additions & 0 deletions src/target/source/codegen_cuda.h
@@ -88,6 +88,8 @@ class CodeGenCUDA final : public CodeGenC {
bool enable_fp16_{false};
// whether enable int8
bool enable_int8_{false};
+ // whether enable warp shuffle intrinsics
+ bool enable_warp_shuffle_{false};
// whether need math_constants.h
bool need_math_constants_h_{false};
// whether need mma.h
37 changes: 33 additions & 4 deletions src/target/source/intrin_rule_cuda.cc
@@ -81,14 +81,34 @@ struct CUDAPopcount {
}
};


+ struct CUDAWarpIntrinsic {
+   const char* operator()(DataType t, const std::string& name) const {
+     if (name == intrinsic::tvm_warp_shuffle) {
+       return "__shfl_sync";
+     }
+     if (name == intrinsic::tvm_warp_shuffle_up) {
+       return "__shfl_up_sync";
+     }
+     if (name == intrinsic::tvm_warp_shuffle_down) {
+       return "__shfl_down_sync";
+     }
+     if (name == intrinsic::tvm_warp_activemask) {
+       return "__activemask";
+     }
+     return "";
+   }
+ };
+
+ template <typename T>
static void DispatchCUDAShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
- CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+ CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2]}};
- *rv = CallNode::make(
-     call->dtype, "__shfl", cuda_args, CallNode::PureExtern);
+ const char* name = T()(call->dtype, call->name);
+ *rv = CallNode::make(call->dtype, name, cuda_args, CallNode::PureExtern);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.floor")
@@ -158,7 +178,16 @@ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.popcount")
.set_body(DispatchExtern<CUDAPopcount>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle")
-   .set_body(DispatchCUDAShuffle);
+   .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+
+ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_up")
+   .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+
+ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_shuffle_down")
+   .set_body(DispatchCUDAShuffle<CUDAWarpIntrinsic>);
+
+ TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.tvm_warp_activemask")
+   .set_body(DispatchExtern<CUDAWarpIntrinsic>);

TVM_REGISTER_GLOBAL("tvm.intrin.rule.cuda.fmod")
.set_body(DispatchExtern<CUDAMath>);
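For illustration (an assumption about the emitted text: only the first three TIR arguments are forwarded, so the CUDA width parameter falls back to its default of warpSize), the lowered calls take roughly this shape:

// Hypothetical shape of the CUDA produced after these rules run; the
// wrapper function is invented, only the two intrinsic calls matter.
__device__ float lowered_shape(float val) {
  unsigned mask = __activemask();         // from tvm_warp_activemask()
  return __shfl_down_sync(mask, val, 1);  // from tvm_warp_shuffle_down(mask, val, 1, ...)
}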
10 changes: 5 additions & 5 deletions src/target/source/intrin_rule_opencl.cc
@@ -94,13 +94,13 @@ static void DispatchIntelShuffle(const TVMArgs& args, TVMRetValue* rv) {
PrimExpr e = args[0];
const CallNode* call = e.as<CallNode>();
CHECK(call != nullptr);
- CHECK_EQ(call->args.size(), 4);  // value, warp_id, width, warp_size
+ CHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  arith::Analyzer analyzer;
- CHECK(analyzer.CanProve(call->args[2] == call->args[3]))
+ CHECK(analyzer.CanProve(call->args[3] == call->args[4]))
      << "Intel warp shuffle does not support width != warp_size";
- Array<PrimExpr> cuda_args{{call->args[0], call->args[1]}};
- *rv = CallNode::make(
-     call->dtype, "intel_sub_group_shuffle", cuda_args, CallNode::PureExtern);
+ Array<PrimExpr> opencl_args{{call->args[1], call->args[2]}};
+ *rv = CallNode::make(call->dtype, "intel_sub_group_shuffle",
+                      opencl_args, CallNode::PureExtern);
}

TVM_REGISTER_GLOBAL("tvm.intrin.rule.opencl.tvm_warp_shuffle")
14 changes: 14 additions & 0 deletions src/target/source/literal/cuda_half_t.h
@@ -295,4 +295,18 @@ __pack_half2(const half x, const half y) {
}
)";

+ static constexpr const char* _cuda_warp_intrinsic_util = R"(
+ #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700)
+ #define __shfl_sync(mask, var, lane, width) \
+         __shfl((var), (lane), (width))
+ #define __shfl_down_sync(mask, var, offset, width) \
+         __shfl_down((var), (offset), (width))
+ #define __shfl_up_sync(mask, var, offset, width) \
+         __shfl_up((var), (offset), (width))
+ #endif
+ )";

#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
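A sketch of the compatibility path (an inference from the __CUDA_ARCH__ guard, with an invented kernel): when the generated source carries the macro block above and targets an architecture below sm_70, each *_sync call is rewritten to the legacy intrinsic with the mask argument dropped, which is safe there because pre-Volta warps execute in lockstep; on sm_70 and newer the real synchronizing intrinsic is used.

// The 4-argument call matches the macro's parameter list exactly, so it
// rewrites cleanly to __shfl_down((v), (1), (32)) under the guard above.
__global__ void shift_down_by_one(float* out) {
  float v = static_cast<float>(threadIdx.x);
  v = __shfl_down_sync(0xffffffffu, v, 1, 32);  // lane i reads lane i + 1
  out[threadIdx.x] = v;
}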
