|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +# pylint: disable=invalid-name |
| 18 | +"""Driver for partitioning and building a Relay module for CUTLASS offload.""" |
| 19 | +import tvm |
| 20 | +from tvm import runtime, relay |
| 21 | +from .gen_gemm import CutlassGemmProfiler |
| 22 | + |
| 23 | + |
| 24 | +class GemmAnnotator(tvm.relay.ExprVisitor): |
| 25 | + """Annotates partitioned functions with shape and dtype information.""" |
| 26 | + |
| 27 | + def __init__(self): |
| 28 | + super().__init__() |
| 29 | + self.signature = {} |
| 30 | + |
| 31 | + def visit_call(self, call): |
| 32 | + op = call.op |
| 33 | + if isinstance(op, relay.Function) and "PartitionedFromPattern" in op.attrs: |
| 34 | + self.signature["op_type"] = op.attrs["Composite"] |
| 35 | + for i, arg in enumerate(op.params): |
| 36 | + self.signature["arg%d_shape" % i] = arg.checked_type.shape |
| 37 | + self.signature["arg%d_dtype" % i] = arg.checked_type.dtype |
| 38 | + self.signature["ret_shape"] = op.ret_type.shape |
| 39 | + self.signature["ret_dtype"] = op.ret_type.dtype |
| 40 | + |
| 41 | + |
| 42 | +def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): |
| 43 | + """Given a module partitioned for CUTLASS offloading, profile each workload to select which |
| 44 | + kernels to emit. |
| 45 | +
|
| 46 | + Parameters |
| 47 | + ---------- |
| 48 | + mod : IRModule |
| 49 | + The Relay module with cutlass partitions. |
| 50 | +
|
| 51 | + sm : int |
| 52 | + An integer specifying the compute capability. For example, 75 for Turing and |
| 53 | + 80 or 86 for Ampere. |
| 54 | +
|
| 55 | + profile_all : bool |
| 56 | + Whether or not profile all candidate kernels, or stop profiling after |
| 57 | + the first applicable kernel is found. |
| 58 | +
|
| 59 | + use_multiprocessing : bool |
| 60 | + Whether or not compile profiler executables for different kernels in parallel. |
| 61 | +
|
| 62 | + tmp_dir : string, optional |
| 63 | + A temporary directory where intermediate compiled artifacts will be stored. |
| 64 | +
|
| 65 | + Returns |
| 66 | + ------- |
| 67 | + mod : IRModule |
| 68 | + The updated module annotated with cutlass profiling information. |
| 69 | +
|
| 70 | + num_cutlass_partition : int |
| 71 | + The number of partitioned functions created for CUTLASS. |
| 72 | + """ |
| 73 | + cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir) |
| 74 | + num_cutlass_partition = 0 |
| 75 | + for var in mod.get_global_vars(): |
| 76 | + fun_name = var.name_hint |
| 77 | + func = mod[fun_name] |
| 78 | + annotator = GemmAnnotator() |
| 79 | + if "cutlass" in fun_name: |
| 80 | + num_cutlass_partition += 1 |
| 81 | + annotator.visit(func) |
| 82 | + # call cutlass profiler to find best settings, update attr |
| 83 | + new_attrs = {} |
| 84 | + new_attrs.update(annotator.signature) |
| 85 | + for key in func.attrs.keys(): |
| 86 | + new_attrs[key] = func.attrs[key] |
| 87 | + # call profiler |
| 88 | + arg0_shape = new_attrs["arg0_shape"] |
| 89 | + arg1_shape = new_attrs["arg1_shape"] |
| 90 | + MM = arg0_shape[0] |
| 91 | + KK = arg0_shape[1] |
| 92 | + NN = arg1_shape[0] |
| 93 | + out = cutlass_profiler.profile( |
| 94 | + MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing |
| 95 | + ) |
| 96 | + if new_attrs["op_type"] == "cutlass.dense": |
| 97 | + new_attrs["cutlass_op_def"] = out["opdef"] |
| 98 | + elif new_attrs["op_type"] == "cutlass.dense_bias": |
| 99 | + new_attrs["cutlass_op_def"] = out["opdef_bias"] |
| 100 | + elif new_attrs["op_type"] == "cutlass.dense_bias_relu": |
| 101 | + new_attrs["cutlass_op_def"] = out["opdef_bias_relu"] |
| 102 | + elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]: |
| 103 | + new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"] |
| 104 | + else: |
| 105 | + raise ValueError("%s pattern is not implemented." % new_attrs["op_type"]) |
| 106 | + new_attrs["cutlass_op_name"] = out["name"] |
| 107 | + |
| 108 | + print("The best kernel is " + new_attrs["cutlass_op_name"]) |
| 109 | + if new_attrs["cutlass_op_name"].find("_tn_align") > 0: |
| 110 | + new_attrs["lda"] = "K" |
| 111 | + new_attrs["ldb"] = "K" |
| 112 | + new_attrs["ldc"] = "N" |
| 113 | + elif new_attrs["cutlass_op_name"].find("_nt_align") > 0: |
| 114 | + new_attrs["lda"] = "M" |
| 115 | + new_attrs["ldb"] = "N" |
| 116 | + new_attrs["ldc"] = "N" |
| 117 | + else: |
| 118 | + raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"]) |
| 119 | + new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) |
| 120 | + new_func = relay.Function( |
| 121 | + func.params, |
| 122 | + func.body, |
| 123 | + ret_type=func.ret_type, |
| 124 | + type_params=func.type_params, |
| 125 | + attrs=new_attrs, |
| 126 | + ) |
| 127 | + mod.update_func(var, new_func) |
| 128 | + |
| 129 | + return mod, num_cutlass_partition |
| 130 | + |
| 131 | + |
| 132 | +def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"): |
| 133 | + """Compile CUTLASS kernels in lib and return the runtime module ready to run. |
| 134 | +
|
| 135 | + Parameters |
| 136 | + ---------- |
| 137 | + lib : GraphExecutorFactoryModule |
| 138 | + The output from relay.build containing compiled host code and non-cutlass kernels. |
| 139 | +
|
| 140 | + sm : int |
| 141 | + An integer specifying the compute capability. For example, 75 for Turing and |
| 142 | + 80 or 86 for Ampere. |
| 143 | +
|
| 144 | + tmp_dir : string, optional |
| 145 | + A temporary directory where intermediate compiled artifacts will be stored. |
| 146 | +
|
| 147 | + lib_path : string, optional |
| 148 | + The path to a shared library which will be generated as the result of the build process |
| 149 | +
|
| 150 | + Returns |
| 151 | + ------- |
| 152 | + updated_lib : runtime.Module |
| 153 | + The updated module with compiled cutlass kernels. |
| 154 | + """ |
| 155 | + cutlass_path = "../../../3rdparty/cutlass/include" |
| 156 | + cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include" |
| 157 | + |
| 158 | + kwargs = {} |
| 159 | + kwargs["cc"] = "nvcc" |
| 160 | + kwargs["options"] = [ |
| 161 | + "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", |
| 162 | + "-gencode=arch=compute_%d,code=[sm_%d,compute_%d]" % (sm, sm, sm), |
| 163 | + "-Xcompiler=-fPIC", |
| 164 | + "-Xcompiler=-Wconversion", |
| 165 | + "-Xcompiler=-fno-strict-aliasing", |
| 166 | + "-O3", |
| 167 | + "-std=c++14", |
| 168 | + "-I" + cutlass_path, |
| 169 | + "-I" + cutlass_util_path, |
| 170 | + ] |
| 171 | + lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs) |
| 172 | + return runtime.load_module(lib_path) |
0 commit comments