Skip to content

Commit

Permalink
Annotate purity for remaining operators
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Mar 26, 2023
1 parent 3b42929 commit 1953c74
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 13 deletions.
24 changes: 16 additions & 8 deletions src/relax/op/tensor/create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ TVM_REGISTER_OP("relax.full")
.add_argument("shape", "Shape", "The shape of the created tensor.")
.add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFull)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.full_like */
Expr full_like(Expr x, Expr fill_value, DataType dtype) {
Expand Down Expand Up @@ -121,7 +122,8 @@ TVM_REGISTER_OP("relax.full_like")
.add_argument("x", "Tensor", "The input tensor.")
.add_argument("fill_value", "Tensor", "The scalar value to fill.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFullLike)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

// Structure info inference for ones and zeros
StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) {
Expand Down Expand Up @@ -178,13 +180,15 @@ TVM_REGISTER_OP("relax.ones")
.set_num_inputs(1)
.add_argument("shape", "Shape", "The shape of the created tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

TVM_REGISTER_OP("relax.ones_like")
.set_attrs_type<InitAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesLikeZerosLike);
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesLikeZerosLike)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.zeros & relax.zeros_like */
Expr zeros(Expr shape, DataType dtype) {
Expand All @@ -211,13 +215,15 @@ TVM_REGISTER_OP("relax.zeros")
.set_num_inputs(1)
.add_argument("shape", "Shape", "The shape of the created tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesZeros)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

TVM_REGISTER_OP("relax.zeros_like")
.set_attrs_type<InitAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesLikeZerosLike);
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoOnesLikeZerosLike)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.tril & relax.triu */
TVM_REGISTER_NODE_TYPE(TriluAttrs);
Expand Down Expand Up @@ -256,13 +262,15 @@ TVM_REGISTER_OP("relax.tril")
.set_attrs_type<TriluAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu);
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu)
.set_attr<Bool>("FPurity", Bool(true));

TVM_REGISTER_OP("relax.triu")
.set_attrs_type<TriluAttrs>()
.set_num_inputs(1)
.add_argument("x", "Tensor", "The input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu);
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTrilTriu)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
6 changes: 4 additions & 2 deletions src/relax/op/tensor/datatype.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ TVM_REGISTER_OP("relax.astype")
.add_argument("x", "Tensor", "The input tensor")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutUnaryEwise)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.wrap_param */
TVM_REGISTER_NODE_TYPE(WrapParamAttrs);
Expand All @@ -83,7 +84,8 @@ TVM_REGISTER_OP("relax.wrap_param")
.set_attrs_type<WrapParamAttrs>()
.set_num_inputs(1)
.add_argument("data", "Tensor", "The input tensor")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoWrapParam);
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoWrapParam)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
6 changes: 4 additions & 2 deletions src/relax/op/tensor/index.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ TVM_REGISTER_OP("relax.take")
.set_num_inputs(2)
.add_argument("x", "Tensor", "The source tensor.")
.add_argument("indices", "Tensor", "The indices of the values to extract.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTake);
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoTake)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.strided_slice */
TVM_REGISTER_NODE_TYPE(StridedSliceAttrs);
Expand Down Expand Up @@ -232,7 +233,8 @@ TVM_REGISTER_OP("relax.strided_slice")
.add_argument("x", "Tensor", "The source tensor to be sliced.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoStridedSlice)
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutStridedSlice)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow);
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm
3 changes: 2 additions & 1 deletion src/relax/op/tensor/linear_algebra.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ TVM_REGISTER_OP("relax.matmul")
.add_argument("x2", "Tensor", "The second input tensor.")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul)
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionMatmul);
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionMatmul)
.set_attr<Bool>("FPurity", Bool(true));

} // namespace relax
} // namespace tvm

0 comments on commit 1953c74

Please sign in to comment.