Skip to content

Commit

Permalink
[BYOC] CUTLASS integration (#9261)
Browse files Browse the repository at this point in the history
* byoc cutlass

* add cmake and fix build

* test worked but accuracy is bad

* fixed argument printing properly

* moving files

* moving contents of cutlass_profiler into python/tvm/contrib/cutlass

* run black

* remove irrelavant codegen code

* clang format

* tried replacing sm 75 with 80, didn't help improve accuracy

* remove irrelavant code from generator

* tried dense + bias fusion but generated cu file does not compile

* dense + bias worked after adding Leyuan's patch, bias + relu worked too

* tried adding sm80 generator but accuracy is still off

* remove GemmUniversal generator

* cleanup partition and build

* moved partition, profile and build function out of test

* turned out the result match's TVM non-cutlass result. Numpy fp16
matmul is busted?

* clean up test

* LinearCombination can be reused for bias only epilogue

* remove unsupported epilogues like gelu

* removing deadcode

* unify gemm templates for with or without beta scaling

* supported gelu but accuracy is slightly off

* gelu test passed with relaxed rtol

* cleanup

* remove unused stuff from library.py

* move profiler template into its own file

* removed gemm_profiler.py

* move contents of compile_engine.py into gen_gemm.py

* rename to profiler_template.cu to avoid CI issue

* cleaning up trying to pass pylint

* add missing asf header

* run black

* fixing many pylint issues except wildcard import

* fixed wildcard warning

* add missing CUTLASS.cmake file, restore gemm_profiler.py

* pylint

* minor fix

* add license

* start filling in TODO doc

* rename GemmProfiler to GemmProfilerEmitter

* more renaming and doc

* add doc to the main compile API

* refactored generator

* run black

* black fix

* finish doc TODO

* add test for 32 bit accum

* fixed kernel generator to correctly handle fp32 accum

* revise build-related API

* add option to profile only one kernel

* add option to enable parallel compilation

* clean up gen_gemm

* doc update

* profile_cutlass_kernels -> tune_cutlass_kernels

Co-authored-by: leyuan.wang <leyuan.wang@bytedance.com>
Co-authored-by: Masahiro Masuda <masahi129@gmail.com>
  • Loading branch information
3 people committed Oct 29, 2021
1 parent dae71db commit 541f9f2
Show file tree
Hide file tree
Showing 18 changed files with 1,864 additions and 6 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
[submodule "3rdparty/libbacktrace"]
path = 3rdparty/libbacktrace
url = https://github.com/tlc-pack/libbacktrace.git
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass
1 change: 1 addition & 0 deletions 3rdparty/cutlass
Submodule cutlass added at a3bcc6
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ tvm_option(USE_MKLDNN "Build with MKLDNN" OFF)
tvm_option(USE_DNNL_CODEGEN "Enable MKLDNN (DNNL) codegen" OFF)
tvm_option(USE_CUDNN "Build with cuDNN" OFF)
tvm_option(USE_CUBLAS "Build with cuBLAS" OFF)
tvm_option(USE_CUTLASS "Build with CUTLASS" OFF)
tvm_option(USE_THRUST "Build with Thrust" OFF)
tvm_option(USE_MIOPEN "Build with ROCM:MIOpen" OFF)
tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
Expand Down Expand Up @@ -428,6 +429,7 @@ include(cmake/modules/contrib/EthosU.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/CUTLASS.cmake)
include(cmake/modules/contrib/ExampleTargetHooks.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Posit.cmake)
Expand Down
5 changes: 5 additions & 0 deletions LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -238,3 +238,8 @@ The Unlicense
-------------

3rdparty/rang

BSD 3-Clause "New" or "Revised" License
---------------------------------------

3rdparty/cutlass
23 changes: 23 additions & 0 deletions cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_CUTLASS)
file(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc)
list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC})

message(STATUS "Build with CUTLASS")
endif()
23 changes: 23 additions & 0 deletions licenses/LICENSE.cutlass.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
Copyright (c) 2017 - 2020, NVIDIA CORPORATION. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the NVIDIA CORPORATION nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
18 changes: 18 additions & 0 deletions python/tvm/contrib/cutlass/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""BYOC support for CUTLASS."""
from .build import tune_cutlass_kernels, build_cutlass_kernels
172 changes: 172 additions & 0 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import tvm
from tvm import runtime, relay
from .gen_gemm import CutlassGemmProfiler


class GemmAnnotator(tvm.relay.ExprVisitor):
"""Annotates partitioned functions with shape and dtype information."""

def __init__(self):
super().__init__()
self.signature = {}

def visit_call(self, call):
op = call.op
if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs:
self.signature["op_type"] = op.attrs["Composite"]
for i, arg in enumerate(op.params):
self.signature["arg%d_shape" % i] = arg.checked_type.shape
self.signature["arg%d_dtype" % i] = arg.checked_type.dtype
self.signature["ret_shape"] = op.ret_type.shape
self.signature["ret_dtype"] = op.ret_type.dtype


def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
Parameters
----------
mod : IRModule
The Relay module with cutlass partitions.
sm : int
An integer specifying the compute capability. For example, 75 for Turing and
80 or 86 for Ampere.
profile_all : bool
Whether or not profile all candidate kernels, or stop profiling after
the first applicable kernel is found.
use_multiprocessing : bool
Whether or not compile profiler executables for different kernels in parallel.
tmp_dir : string, optional
A temporary directory where intermediate compiled artifacts will be stored.
Returns
-------
mod : IRModule
The updated module annotated with cutlass profiling information.
num_cutlass_partition : int
The number of partitioned functions created for CUTLASS.
"""
cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir)
num_cutlass_partition = 0
for var in mod.get_global_vars():
fun_name = var.name_hint
func = mod[fun_name]
annotator = GemmAnnotator()
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
# call cutlass profiler to find best settings, update attr
new_attrs = {}
new_attrs.update(annotator.signature)
for key in func.attrs.keys():
new_attrs[key] = func.attrs[key]
# call profiler
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]
out = cutlass_profiler.profile(
MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing
)
if new_attrs["op_type"] == "cutlass.dense":
new_attrs["cutlass_op_def"] = out["opdef"]
elif new_attrs["op_type"] == "cutlass.dense_bias":
new_attrs["cutlass_op_def"] = out["opdef_bias"]
elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
else:
raise ValueError("%s pattern is not implemented." % new_attrs["op_type"])
new_attrs["cutlass_op_name"] = out["name"]

print("The best kernel is " + new_attrs["cutlass_op_name"])
if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
new_attrs["lda"] = "K"
new_attrs["ldb"] = "K"
new_attrs["ldc"] = "N"
elif new_attrs["cutlass_op_name"].find("_nt_align") > 0:
new_attrs["lda"] = "M"
new_attrs["ldb"] = "N"
new_attrs["ldc"] = "N"
else:
raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"])
new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
new_func = relay.Function(
func.params,
func.body,
ret_type=func.ret_type,
type_params=func.type_params,
attrs=new_attrs,
)
mod.update_func(var, new_func)

return mod, num_cutlass_partition


def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"):
"""Compile CUTLASS kernels in lib and return the runtime module ready to run.
Parameters
----------
lib : GraphExecutorFactoryModule
The output from relay.build containing compiled host code and non-cutlass kernels.
sm : int
An integer specifying the compute capability. For example, 75 for Turing and
80 or 86 for Ampere.
tmp_dir : string, optional
A temporary directory where intermediate compiled artifacts will be stored.
lib_path : string, optional
The path to a shared library which will be generated as the result of the build process
Returns
-------
updated_lib : runtime.Module
The updated module with compiled cutlass kernels.
"""
cutlass_path = "../../../3rdparty/cutlass/include"
cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include"

kwargs = {}
kwargs["cc"] = "nvcc"
kwargs["options"] = [
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm),
"-Xcompiler=-fPIC",
"-Xcompiler=-Wconversion",
"-Xcompiler=-fno-strict-aliasing",
"-O3",
"-std=c++14",
"-I" + cutlass_path,
"-I" + cutlass_util_path,
]
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
return runtime.load_module(lib_path)
Loading

0 comments on commit 541f9f2

Please sign in to comment.