Skip to content

Commit 541f9f2

Browse files
Laurawlyleyuan.wangmasahi
authored
[BYOC] CUTLASS integration (#9261)
* byoc cutlass * add cmake and fix build * test worked but accuracy is bad * fixed argument printing properly * moving files * moving contents of cutlass_profiler into python/tvm/contrib/cutlass * run black * remove irrelavant codegen code * clang format * tried replacing sm 75 with 80, didn't help improve accuracy * remove irrelavant code from generator * tried dense + bias fusion but generated cu file does not compile * dense + bias worked after adding Leyuan's patch, bias + relu worked too * tried adding sm80 generator but accuracy is still off * remove GemmUniversal generator * cleanup partition and build * moved partition, profile and build function out of test * turned out the result match's TVM non-cutlass result. Numpy fp16 matmul is busted? * clean up test * LinearCombination can be reused for bias only epilogue * remove unsupported epilogues like gelu * removing deadcode * unify gemm templates for with or without beta scaling * supported gelu but accuracy is slightly off * gelu test passed with relaxed rtol * cleanup * remove unused stuff from library.py * move profiler template into its own file * removed gemm_profiler.py * move contents of compile_engine.py into gen_gemm.py * rename to profiler_template.cu to avoid CI issue * cleaning up trying to pass pylint * add missing asf header * run black * fixing many pylint issues except wildcard import * fixed wildcard warning * add missing CUTLASS.cmake file, restore gemm_profiler.py * pylint * minor fix * add license * start filling in TODO doc * rename GemmProfiler to GemmProfilerEmitter * more renaming and doc * add doc to the main compile API * refactored generator * run black * black fix * finish doc TODO * add test for 32 bit accum * fixed kernel generator to correctly handle fp32 accum * revise build-related API * add option to profile only one kernel * add option to enable parallel compilation * clean up gen_gemm * doc update * profile_cutlass_kernels -> tune_cutlass_kernels Co-authored-by: leyuan.wang <leyuan.wang@bytedance.com> Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
1 parent dae71db commit 541f9f2

File tree

18 files changed

+1864
-6
lines changed

18 files changed

+1864
-6
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@
1313
[submodule "3rdparty/libbacktrace"]
1414
path = 3rdparty/libbacktrace
1515
url = https://github.com/tlc-pack/libbacktrace.git
16+
[submodule "3rdparty/cutlass"]
17+
path = 3rdparty/cutlass
18+
url = https://github.com/NVIDIA/cutlass

3rdparty/cutlass

Submodule cutlass added at a3bcc69

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
6969
tvm_option(USE_DNNL_CODEGEN "Enable MKLDNN (DNNL) codegen" OFF)
7070
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
7171
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
72+
tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
7273
tvm_option(USE_THRUST "Build with Thrust" OFF)
7374
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
7475
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
@@ -428,6 +429,7 @@ include(cmake/modules/contrib/EthosU.cmake)
428429
include(cmake/modules/contrib/BLAS.cmake)
429430
include(cmake/modules/contrib/CODEGENC.cmake)
430431
include(cmake/modules/contrib/DNNL.cmake)
432+
include(cmake/modules/contrib/CUTLASS.cmake)
431433
include(cmake/modules/contrib/ExampleTargetHooks.cmake)
432434
include(cmake/modules/contrib/Random.cmake)
433435
include(cmake/modules/contrib/Posit.cmake)

LICENSE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,8 @@ The Unlicense
238238
-------------
239239

240240
3rdparty/rang
241+
242+
BSD 3-Clause "New" or "Revised" License
243+
---------------------------------------
244+
245+
3rdparty/cutlass
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
if(USE_CUTLASS)
19+
file(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc)
20+
list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC})
21+
22+
message(STATUS "Build with CUTLASS")
23+
endif()

licenses/LICENSE.cutlass.txt

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.
2+
3+
Redistribution and use in source and binary forms, with or without
4+
modification, are permitted provided that the following conditions are met:
5+
* Redistributions of source code must retain the above copyright
6+
notice, this list of conditions and the following disclaimer.
7+
* Redistributions in binary form must reproduce the above copyright
8+
notice, this list of conditions and the following disclaimer in the
9+
documentation and/or other materials provided with the distribution.
10+
* Neither the name of the NVIDIA CORPORATION nor the
11+
names of its contributors may be used to endorse or promote products
12+
derived from this software without specific prior written permission.
13+
14+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
15+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
16+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
17+
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
18+
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
19+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
20+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
21+
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
22+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
23+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""BYOC support for CUTLASS."""
18+
from .build import tune_cutlass_kernels, build_cutlass_kernels
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=invalid-name
18+
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
19+
import tvm
20+
from tvm import runtime, relay
21+
from .gen_gemm import CutlassGemmProfiler
22+
23+
24+
class GemmAnnotator(tvm.relay.ExprVisitor):
25+
"""Annotates partitioned functions with shape and dtype information."""
26+
27+
def __init__(self):
28+
super().__init__()
29+
self.signature = {}
30+
31+
def visit_call(self, call):
32+
op = call.op
33+
if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs:
34+
self.signature["op_type"] = op.attrs["Composite"]
35+
for i, arg in enumerate(op.params):
36+
self.signature["arg%d_shape" % i] = arg.checked_type.shape
37+
self.signature["arg%d_dtype" % i] = arg.checked_type.dtype
38+
self.signature["ret_shape"] = op.ret_type.shape
39+
self.signature["ret_dtype"] = op.ret_type.dtype
40+
41+
42+
def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
43+
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
44+
kernels to emit.
45+
46+
Parameters
47+
----------
48+
mod : IRModule
49+
The Relay module with cutlass partitions.
50+
51+
sm : int
52+
An integer specifying the compute capability. For example, 75 for Turing and
53+
80 or 86 for Ampere.
54+
55+
profile_all : bool
56+
Whether or not profile all candidate kernels, or stop profiling after
57+
the first applicable kernel is found.
58+
59+
use_multiprocessing : bool
60+
Whether or not compile profiler executables for different kernels in parallel.
61+
62+
tmp_dir : string, optional
63+
A temporary directory where intermediate compiled artifacts will be stored.
64+
65+
Returns
66+
-------
67+
mod : IRModule
68+
The updated module annotated with cutlass profiling information.
69+
70+
num_cutlass_partition : int
71+
The number of partitioned functions created for CUTLASS.
72+
"""
73+
cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir)
74+
num_cutlass_partition = 0
75+
for var in mod.get_global_vars():
76+
fun_name = var.name_hint
77+
func = mod[fun_name]
78+
annotator = GemmAnnotator()
79+
if "cutlass" in fun_name:
80+
num_cutlass_partition += 1
81+
annotator.visit(func)
82+
# call cutlass profiler to find best settings, update attr
83+
new_attrs = {}
84+
new_attrs.update(annotator.signature)
85+
for key in func.attrs.keys():
86+
new_attrs[key] = func.attrs[key]
87+
# call profiler
88+
arg0_shape = new_attrs["arg0_shape"]
89+
arg1_shape = new_attrs["arg1_shape"]
90+
MM = arg0_shape[0]
91+
KK = arg0_shape[1]
92+
NN = arg1_shape[0]
93+
out = cutlass_profiler.profile(
94+
MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing
95+
)
96+
if new_attrs["op_type"] == "cutlass.dense":
97+
new_attrs["cutlass_op_def"] = out["opdef"]
98+
elif new_attrs["op_type"] == "cutlass.dense_bias":
99+
new_attrs["cutlass_op_def"] = out["opdef_bias"]
100+
elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
101+
new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
102+
elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
103+
new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
104+
else:
105+
raise ValueError("%s pattern is not implemented." % new_attrs["op_type"])
106+
new_attrs["cutlass_op_name"] = out["name"]
107+
108+
print("The best kernel is " + new_attrs["cutlass_op_name"])
109+
if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
110+
new_attrs["lda"] = "K"
111+
new_attrs["ldb"] = "K"
112+
new_attrs["ldc"] = "N"
113+
elif new_attrs["cutlass_op_name"].find("_nt_align") > 0:
114+
new_attrs["lda"] = "M"
115+
new_attrs["ldb"] = "N"
116+
new_attrs["ldc"] = "N"
117+
else:
118+
raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"])
119+
new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
120+
new_func = relay.Function(
121+
func.params,
122+
func.body,
123+
ret_type=func.ret_type,
124+
type_params=func.type_params,
125+
attrs=new_attrs,
126+
)
127+
mod.update_func(var, new_func)
128+
129+
return mod, num_cutlass_partition
130+
131+
132+
def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"):
133+
"""Compile CUTLASS kernels in lib and return the runtime module ready to run.
134+
135+
Parameters
136+
----------
137+
lib : GraphExecutorFactoryModule
138+
The output from relay.build containing compiled host code and non-cutlass kernels.
139+
140+
sm : int
141+
An integer specifying the compute capability. For example, 75 for Turing and
142+
80 or 86 for Ampere.
143+
144+
tmp_dir : string, optional
145+
A temporary directory where intermediate compiled artifacts will be stored.
146+
147+
lib_path : string, optional
148+
The path to a shared library which will be generated as the result of the build process
149+
150+
Returns
151+
-------
152+
updated_lib : runtime.Module
153+
The updated module with compiled cutlass kernels.
154+
"""
155+
cutlass_path = "../../../3rdparty/cutlass/include"
156+
cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include"
157+
158+
kwargs = {}
159+
kwargs["cc"] = "nvcc"
160+
kwargs["options"] = [
161+
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
162+
"-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm),
163+
"-Xcompiler=-fPIC",
164+
"-Xcompiler=-Wconversion",
165+
"-Xcompiler=-fno-strict-aliasing",
166+
"-O3",
167+
"-std=c++14",
168+
"-I" + cutlass_path,
169+
"-I" + cutlass_util_path,
170+
]
171+
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
172+
return runtime.load_module(lib_path)

0 commit comments

Comments
 (0)