[Perf] Enhance cudnn and cublas backend and enable TensorCore (#4353)
* add half and mix precision support to cublas backend

* add TensorCore support in CuDNN

* enhance CuDNN support

* address comments and fix lint

* fix

* add fp16 test
Hzfengsy authored and Laurawly committed Nov 25, 2019
1 parent fbb2a35 commit dabde40
Showing 11 changed files with 436 additions and 96 deletions.
8 changes: 8 additions & 0 deletions include/tvm/runtime/util.h
@@ -39,6 +39,14 @@ namespace runtime {
inline bool TypeMatch(TVMType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
/*!
 * \brief Check whether two types are equal.
* \param lhs The left operand.
* \param rhs The right operand.
*/
inline bool TypeEqual(TVMType lhs, TVMType rhs) {
return lhs.code == rhs.code && lhs.bits == rhs.bits && lhs.lanes == rhs.lanes;
}
} // namespace runtime
} // namespace tvm
// Forward declare the intrinsic id we need
10 changes: 6 additions & 4 deletions python/tvm/contrib/cublas.py
@@ -20,7 +20,7 @@
from .. import api as _api
from .. import intrin as _intrin

def matmul(lhs, rhs, transa=False, transb=False):
def matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute matrix mult of A and rhs with cuBLAS
Parameters
@@ -41,13 +41,14 @@ def matmul(lhs, rhs, transa=False, transb=False):
"""
n = lhs.shape[1] if transa else lhs.shape[0]
m = rhs.shape[0] if transb else rhs.shape[1]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")

def batch_matmul(lhs, rhs, transa=False, transb=False):
def batch_matmul(lhs, rhs, transa=False, transb=False, dtype=None):
"""Create an extern op that compute batch matrix mult of A and rhs with cuBLAS
Parameters
@@ -69,8 +70,9 @@ def batch_matmul(lhs, rhs, transa=False, transb=False):
b = lhs.shape[0]
n = lhs.shape[2] if transa else lhs.shape[1]
m = rhs.shape[1] if transb else rhs.shape[2]
dtype = dtype if dtype is not None else lhs.dtype
return _api.extern(
(b, n, m), [lhs, rhs],
lambda ins, outs: _intrin.call_packed(
"tvm.contrib.cublas.batch_matmul",
ins[0], ins[1], outs[0], transa, transb), name="C")
ins[0], ins[1], outs[0], transa, transb), dtype=dtype, name="C")
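
To illustrate the new `dtype` argument, here is a minimal sketch of a mixed-precision batched GEMM. It assumes the TVM v0.6-era Python API (`tvm.placeholder`, `tvm.create_schedule`) and a CUDA-enabled build with cuBLAS support; the shapes and the float16/float32 pairing are illustrative, not taken from this commit:

```python
import tvm
from tvm.contrib import cublas

# float16 inputs, float32 accumulation/output via the new `dtype` argument.
A = tvm.placeholder((16, 1024, 128), name='A', dtype='float16')
B = tvm.placeholder((16, 128, 256), name='B', dtype='float16')
C = cublas.batch_matmul(A, B, dtype='float32')  # dtype overrides lhs.dtype
s = tvm.create_schedule(C.op)
# Build for CUDA and run as usual; cuBLAS can dispatch to a
# TensorCore-capable kernel when the device and shapes allow it.
```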
65 changes: 47 additions & 18 deletions python/tvm/contrib/cudnn.py
@@ -22,7 +22,6 @@
from .. import intrin as _intrin
from .. import get_global_func as _get_global_func


# algos can be read from cudnn.h
_FWD_ALGOS = [
"CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM",
@@ -67,6 +66,7 @@
"bwd_data"
]


def algo_to_index(algo_type, algo_name):
"""Return a index represents the algorithm, which can be used in
calling CuDNN function
@@ -172,6 +172,7 @@ def conv2d_w_shape(in_channel,
"""
return [out_channel, in_channel, filter_h, filter_w]


def conv2d_output_shape(tensor_format,
pad_h,
pad_w,
@@ -180,7 +181,9 @@ def conv2d_output_shape(tensor_format,
dilation_h,
dilation_w,
x_shape,
w_shape):
w_shape,
data_dtype,
conv_dtype):
"""Get output shape of 2D convolution
Parameters
@@ -232,7 +235,9 @@ def conv2d_output_shape(tensor_format,
w_shape[1].value,
w_shape[2].value,
w_shape[3].value,
_get_np_int32_array_handle(oshape))
_get_np_int32_array_handle(oshape),
data_dtype,
conv_dtype)
return list(oshape)


@@ -245,7 +250,9 @@ def conv2d_find_algo(tensor_format,
dilation_w,
x_shape,
w_shape,
y_shape):
y_shape,
data_dtype,
conv_dtype):
"""Choose the best algo for the given input.
Parameters
@@ -272,6 +279,10 @@ def conv2d_find_algo(tensor_format,
weight shape
y_shape: list
output shape
data_dtype: str
data type of the input and filter tensors
conv_dtype: str
data type used for the convolution computation
Returns
-------
@@ -297,7 +308,9 @@ def conv2d_find_algo(tensor_format,
int(y_shape[0]),
int(y_shape[1]),
int(y_shape[2]),
int(y_shape[3]))
int(y_shape[3]),
data_dtype,
conv_dtype)


def conv2d_forward(x,
@@ -310,7 +323,8 @@ def conv2d_forward(x,
dilation_w=1,
conv_mode=1,
tensor_format=0,
algo=-1):
algo=-1,
conv_dtype=None):
"""Create an extern op that compute 2D convolution with CuDNN
Parameters
@@ -341,12 +355,16 @@ def conv2d_forward(x,
algo: int
Forward algorithm, get index from the ```algo_to_index``` function;
if algo == -1, the best algorithm will be chosen by CuDNN
conv_dtype: str
data type used for the convolution computation
Returns
-------
y: Tensor
The result tensor
"""
conv_dtype = x.dtype if conv_dtype is None else conv_dtype

oshape = conv2d_output_shape(tensor_format,
pad_h,
pad_w,
@@ -355,18 +373,28 @@ def conv2d_forward(x,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape))
list(w.shape),
x.dtype,
conv_dtype)
if algo == -1:
algo = conv2d_find_algo(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape),
oshape)
# For now, calling `cudnnFindConvolutionForwardAlgorithm` with the INT8
# data type crashes CuDNN.
# On the other hand, CuDNN only supports IMPLICIT_PRECOMP_GEMM for the NHWC
# format, so that algorithm is selected directly in this case.
if tensor_format == 1 and conv_dtype == "int32":
algo = 1
else:
algo = conv2d_find_algo(tensor_format,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
list(x.shape),
list(w.shape),
oshape,
x.dtype,
conv_dtype)

return _api.extern(
oshape, [x, w],
@@ -383,4 +411,5 @@ def conv2d_forward(x,
dilation_w,
ins[0],
ins[1],
outs[0]), name="y")
outs[0],
conv_dtype), name="y")
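
The new `conv_dtype` parameter separates the storage type of the tensors from the computation type of the convolution, which is what enables the TensorCore path (float16 storage with float32 accumulation). A minimal usage sketch under the same v0.6-era API assumptions as above, with illustrative shapes and keyword arguments taken from the parameter names shown in this diff:

```python
import tvm
from tvm.contrib import cudnn

# NCHW float16 tensors, convolution computed in float32.
X = tvm.placeholder((4, 64, 32, 32), name='X', dtype='float16')
W = tvm.placeholder((32, 64, 3, 3), name='W', dtype='float16')
Y = cudnn.conv2d_forward(X, W,
                         stride_h=1, stride_w=1,
                         pad_h=1, pad_w=1,
                         dilation_h=1, dilation_w=1,
                         conv_mode=1,      # cross-correlation
                         tensor_format=0,  # CUDNN_TENSOR_NCHW
                         algo=-1,          # let conv2d_find_algo choose
                         conv_dtype='float32')
s = tvm.create_schedule(Y.op)
f = tvm.build(s, [X, W, Y], 'cuda', target_host='llvm')
```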
9 changes: 5 additions & 4 deletions src/runtime/contrib/cblas/gemm_common.h
@@ -93,13 +93,13 @@ inline void CallGemm(TVMArgs args, TVMRetValue *ret, TGemmOp op) {
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
op(transb, transa, ColumnCount(B, transb), RowCount(A, transa),
ColumnCount(A, transa), static_cast<float>(alpha),
ColumnCount(A, transa), static_cast<typename TGemmOp::TDatatype>(alpha),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(B->data) + B->byte_offset),
ColumnStride(B),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(A->data) + A->byte_offset),
ColumnStride(A), static_cast<float>(beta),
ColumnStride(A), static_cast<typename TGemmOp::TDatatype>(beta),
reinterpret_cast<typename TGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset),
ColumnStride(C));
@@ -170,9 +170,10 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue *ret, TBatchGemmOp op) {
DType *C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype *>(
static_cast<char *>(C->data) + C->byte_offset);
op(batch_size, transb, transa, ColumnCount3D(B, transb),
RowCount3D(A, transa), ColumnCount3D(A, transa), static_cast<float>(alpha),
RowCount3D(A, transa), ColumnCount3D(A, transa),
static_cast<typename TBatchGemmOp::TDatatype>(alpha),
B_data, B_size, ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
static_cast<float>(beta), C_data, C_size, ColumnStride3D(C));
static_cast<typename TBatchGemmOp::TDatatype>(beta), C_data, C_size, ColumnStride3D(C));
}

} // namespace contrib
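
These GEMM helpers are shared by the BLAS-style backends, so casting alpha and beta to `TGemmOp::TDatatype` instead of hard-coded `float` is what lets a half-precision GEMM flow through correctly. The commit message's "add fp16 test" exercises this end to end; below is a minimal sketch of such a test (shapes and the tolerance are assumptions, not copied from the commit):

```python
import numpy as np
import tvm
from tvm.contrib import cublas

def verify_matmul_fp16(n=1024, l=128, m=256):
    A = tvm.placeholder((n, l), name='A', dtype='float16')
    B = tvm.placeholder((l, m), name='B', dtype='float16')
    C = cublas.matmul(A, B, dtype='float32')
    s = tvm.create_schedule(C.op)
    ctx = tvm.gpu(0)
    f = tvm.build(s, [A, B, C], 'cuda', target_host='llvm')
    a_np = np.random.uniform(size=(n, l)).astype('float16')
    b_np = np.random.uniform(size=(l, m)).astype('float16')
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(b_np, ctx)
    c = tvm.nd.array(np.zeros((n, m), dtype='float32'), ctx)
    f(a, b, c)
    # float16 inputs lose precision, so compare loosely.
    np.testing.assert_allclose(c.asnumpy(),
                               np.dot(a_np.astype('float32'),
                                      b_np.astype('float32')),
                               rtol=1e-2)
```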
