Skip to content

Commit

Permalink
use DenseInferLayout for matmul
Browse files Browse the repository at this point in the history
Change-Id: I980d9ff0ed842f1b8176057ee070779427b0a896
  • Loading branch information
lhutton1 committed May 8, 2024
1 parent 59671d4 commit 4b8d83c
Showing 1 changed file with 1 addition and 16 deletions.
17 changes: 1 addition & 16 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,21 +178,6 @@ Expr MakeMatmul(Expr tensor_a, Expr tensor_b, IndexExpr units, DataType out_dtyp
return Call(matmul_op, {tensor_a, tensor_b}, Attrs(attrs), {});
}

InferCorrectLayoutOutput MatmulInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* params = attrs.as<MatmulAttrs>();
ICHECK(params);

bool transpose_a = params->transpose_a;
bool transpose_b = params->transpose_b;
String layout_a = transpose_a ? "CN" : "NC";
String layout_b = transpose_b ? "CN" : "NC";

return InferCorrectLayoutOutput({layout_a, layout_b}, {"NC"}, attrs);
}

TVM_REGISTER_GLOBAL("relay.op.nn._make.matmul").set_body_typed(MakeMatmul);

RELAY_REGISTER_OP("nn.matmul")
Expand All @@ -208,7 +193,7 @@ RELAY_REGISTER_OP("nn.matmul")
.add_argument("tensor_a", "nD Tensor", "The first input Tensor.")
.add_argument("tensor_b", "2D Tensor", "The second input Tensor.")
.set_support_level(1)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", MatmulInferCorrectLayout)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", DenseInferCorrectLayout)
.add_type_rel("Matmul", MatmulRel<MatmulAttrs>)
.set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);

Expand Down

0 comments on commit 4b8d83c

Please sign in to comment.