Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,13 +633,17 @@ def __init__(
def handle_conv2d(self, f, op_type):
"""Tune and annotate a conv2d op."""
signature = _extract_relax_function_signature(f)
arg_idx = _extract_arg_idx(op_type, f)
op_attrs = _get_call_node(f.body, "relax.nn.conv2d").attrs

d_shape = signature["arg0_shape"]
w_shape = signature["arg1_shape"]
data_arg = f"arg{arg_idx['lhs']}"
weight_arg = f"arg{arg_idx['rhs']}"

d_shape = signature[f"{data_arg}_shape"]
w_shape = signature[f"{weight_arg}_shape"]
out_shape = signature["ret_shape"]
data_dtype = signature["arg0_dtype"]
weight_dtype = signature["arg1_dtype"]
data_dtype = signature[f"{data_arg}_dtype"]
weight_dtype = signature[f"{weight_arg}_dtype"]
out_dtype = signature["ret_dtype"]
padding = op_attrs["padding"]
strides = op_attrs["strides"]
Expand Down Expand Up @@ -673,6 +677,10 @@ def handle_conv2d(self, f, op_type):
return f.with_attrs(
{
"op_type": op_type,
"data_arg_idx": arg_idx["lhs"],
"weight_arg_idx": arg_idx["rhs"],
"bias_arg_idx": arg_idx.get("bias"),
"residual_arg_idx": arg_idx.get("residual"),
"arg0_dtype": data_dtype,
"arg1_dtype": weight_dtype,
"ret_dtype": out_dtype,
Expand Down
15 changes: 6 additions & 9 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def emit(
return substitute_template(template, values)


def instantiate_conv2d_template(attrs, func_args):
def instantiate_conv2d_template(attrs):
"""Return CUTLASS host code for conv2d based on a template and the provided attribute map."""
template = """
${cutlass_op_def}
Expand Down Expand Up @@ -382,8 +382,8 @@ def instantiate_conv2d_template(attrs, func_args):
cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices);
const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::${split_k_mode};

void* ptr_a = (void*)(${arg0}->data);
void* ptr_b = (void*)(${arg1}->data);
void* ptr_a = (void*)(${data_arg}->data);
void* ptr_b = (void*)(${weight_arg}->data);
${bias_decl}
${residual_decl}
void* ptr_out = (void*)(out0->data);
Expand Down Expand Up @@ -481,12 +481,12 @@ def instantiate_conv2d_template(attrs, func_args):
aux_map["beta"] = "1"

if has_residual_blcok:
aux_map["bias_decl"] = "void* ptr_bias = (void*)(${arg2}->data);\n"
aux_map["residual_decl"] = "void* ptr_residual = (void*)(${arg3}->data);"
aux_map["bias_decl"] = "void* ptr_bias = (void*)(${bias_arg}->data);\n"
aux_map["residual_decl"] = "void* ptr_residual = (void*)(${residual_arg}->data);"
aux_map["tensor_c"] = "ptr_residual"
aux_map["tensor_c_layout"] = "layout_C"
elif has_bias:
aux_map["bias_decl"] = "void* ptr_c_bias = (void*)(${arg2}->data);\n"
aux_map["bias_decl"] = "void* ptr_c_bias = (void*)(${bias_arg}->data);\n"
aux_map["residual_decl"] = ""
aux_map["tensor_c"] = "ptr_c_bias"
aux_map["tensor_c_layout"] = "cutlass::layout::TensorNHWC::Stride(0)"
Expand Down Expand Up @@ -534,7 +534,4 @@ def instantiate_conv2d_template(attrs, func_args):

template = substitute_template(template, aux_map)

for i, arg in enumerate(func_args):
attrs["arg{}".format(i)] = arg

return substitute_template(template, attrs)
34 changes: 23 additions & 11 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@
from tvm.tir import IntImm

from . import _ffi_api as ffi
from .attention_operation import instantiate_attention_template
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template
from .attention_operation import instantiate_attention_template
from .library import (
DataType,
DataTypeSize,
Expand Down Expand Up @@ -463,9 +463,9 @@ def __init__(self, code, headers):

def _get_optional_int_annotation(annotations, key, default=None):
value = annotations.get(key, default)
if value is not None:
return int(value)
return value
if value is None:
return default
return int(value)


@tvm._ffi.register_func("contrib.cutlass.instantiate_template")
Expand Down Expand Up @@ -614,10 +614,21 @@ def get_batch_on_arg(arg_name, arg_shape):
return CodegenResult(code, headers)

elif "conv2d" in func_name:
activation_shape = arg0_shape
weight_shape = arg1_shape
data_arg_idx = _get_optional_int_annotation(annotations, "data_arg_idx", 0)
weight_arg_idx = _get_optional_int_annotation(annotations, "weight_arg_idx", 1)
bias_arg_idx = _get_optional_int_annotation(annotations, "bias_arg_idx", 2)
residual_arg_idx = _get_optional_int_annotation(annotations, "residual_arg_idx", 3)

attrs["data_arg"] = func_args[data_arg_idx]
attrs["weight_arg"] = func_args[weight_arg_idx]
if len(func_args) > bias_arg_idx:
attrs["bias_arg"] = func_args[bias_arg_idx]
if len(func_args) > residual_arg_idx:
attrs["residual_arg"] = func_args[residual_arg_idx]

activation_shape = annotations[f"arg{data_arg_idx}_shape"]
weight_shape = annotations[f"arg{weight_arg_idx}_shape"]
output_shape = annotations["ret_shape"]
activation_var = func_args[0]

if "conv2d_transpose" in func_name:
headers.append("cutlass/conv/kernel/default_conv2d_dgrad.h")
Expand All @@ -643,9 +654,10 @@ def get_batch_on_arg(arg_name, arg_shape):
"cutlass/reduction/thread/reduction_operators.h",
]

attrs["N"] = get_dim(activation_shape[0], activation_var, 0)
attrs["H"] = get_dim(activation_shape[1], activation_var, 1)
attrs["W"] = get_dim(activation_shape[2], activation_var, 2)
data_arg = attrs["data_arg"]
attrs["N"] = get_dim(activation_shape[0], data_arg, 0)
attrs["H"] = get_dim(activation_shape[1], data_arg, 1)
attrs["W"] = get_dim(activation_shape[2], data_arg, 2)
attrs["C"] = activation_shape[3]
attrs["P"] = get_dim(output_shape[1], "out0", 1)
attrs["Q"] = get_dim(output_shape[2], "out0", 2)
Expand All @@ -666,7 +678,7 @@ def get_batch_on_arg(arg_name, arg_shape):
attrs["split_k_mode"] = "kSerial"
attrs["split_k_slices"] = 1

code = instantiate_conv2d_template(attrs, func_args)
code = instantiate_conv2d_template(attrs)
return CodegenResult(code, headers)

elif "attention" in func_name:
Expand Down
Loading