[Unity] Add support for AXIS_SEPARATOR in AlterOpImpl Pass #15315

Merged: 2 commits, Jul 24, 2023
9 changes: 9 additions & 0 deletions include/tvm/relax/attrs/manipulate.h
@@ -59,12 +59,21 @@ struct LayoutTransformAttrs : public tvm::AttrsNode<LayoutTransformAttrs> {
// pad_value is chosen to be of PrimValue type, as it represents constant TIR POD expression. This
// needs to be revisited in case PrimValue is evolved to represent symbolic expression in future.
Optional<PrimValue> pad_value;
/*!
* axis_separators between input axes when generating flattened output axes. For buffers
* representing flat 1-d memory (e.g. any buffer in RAM), this should be an empty array.
* For buffers representing non-flat memory, each entry in axis_separators should be the
* first input axis that is part of a new flattened axis.
*/
Optional<Array<IntImm>> axis_separators;

TVM_DECLARE_ATTRS(LayoutTransformAttrs, "relax.attrs.LayoutTransformAttrs") {
TVM_ATTR_FIELD(index_map).describe("The layout transformation to apply.");
TVM_ATTR_FIELD(pad_value).describe(
"The specific value to be used to pad if the layout transform would result in implicit "
"padding. If not specified, the compiler is free to choose any value.");
TVM_ATTR_FIELD(axis_separators)
.describe("The separators between input axes when generating flat output axes");
}
}; // struct LayoutTransformAttrs

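For context beyond the diff: an index map and its axis separators are usually produced together. A minimal Python sketch using the existing mainline helpers IndexMap.from_func_with_separators and IndexMap.AXIS_SEPARATOR (the NCHW-to-NCHW4c mapping and the split point are illustrative assumptions, not part of this PR):

from tvm.tir import IndexMap

# Remap an NCHW buffer to NCHW4c and mark a physical split after the first
# two output axes. IndexMap.AXIS_SEPARATOR is a sentinel: it is stripped from
# the final indices and reported back as an axis-separator position.
index_map, axis_separators = IndexMap.from_func_with_separators(
    lambda n, c, h, w: [n, c // 4, IndexMap.AXIS_SEPARATOR, h, w, c % 4]
)
# axis_separators == [2]: output axes [n, c // 4] flatten into one physical
# axis and [h, w, c % 4] into a second, i.e. a non-flat (2-d) buffer.
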
4 changes: 3 additions & 1 deletion include/tvm/relax/transform.h
@@ -464,10 +464,12 @@ TVM_DLL Pass DecomposeOpsForTraining(Optional<String> func_name);
* \param op_impl_map Map from kOperatorName attr (e.g., relax.conv2d) to replacement PrimFunc
* \param op_buffer_transforms Map from kOperatorName attr to layout transformations on each of the
* PrimFunc i/o buffers.
* \param axis_separators Map from kOperatorName attr to the axis_separators of each buffer in op_buffer_transforms
* \return The Pass.
*/
TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms);
const Map<String, Array<tir::IndexMap>>& op_buffer_transforms,
const Map<String, Array<Array<IntImm>>>& axis_separators);

/*!
* \brief Layout conversion pass.
10 changes: 9 additions & 1 deletion python/tvm/relax/op/manipulate.py
@@ -114,6 +114,7 @@ def layout_transform(
x: Expr,
index_map: Union[Callable, IndexMap],
pad_value: Optional[Union[int, float, PrimValue]] = None,
axis_separators: Optional[Union[int, IndexMap.AXIS_SEPARATOR]] = None,
):
"""Modifies the layout of a tensor.

@@ -129,6 +130,9 @@
The value used for padding if the transformation results in implicit padding.
If not specified, any value can be used.

axis_separators : Optional[Union[int, IndexMap.AXIS_SEPARATOR]]
The axis_separators for the index_map, used to create non-flat buffers.

Returns
-------
result : relax.Expr
@@ -150,7 +154,11 @@
elif "float" in x_dtype and (isinstance(pad_value, (int, float))):
pad_value = FloatImm(x_dtype, float(pad_value))
pad_value = PrimValue(pad_value)
return _ffi_api.layout_transform(x, index_map, pad_value) # type: ignore

if axis_separators is None:
axis_separators = []

return _ffi_api.layout_transform(x, index_map, pad_value, axis_separators) # type: ignore


def permute_dims(x: Expr, axes: Optional[List[int]] = None) -> Expr:
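
A usage sketch for the updated operator, reusing the mapping from the earlier sketch; the input shape and variable names are illustrative, and only the axis_separators argument is new in this PR:

from tvm import relax
from tvm.tir import IndexMap

x = relax.Var("x", relax.TensorStructInfo((8, 64, 32, 32), "float32"))
index_map, axis_separators = IndexMap.from_func_with_separators(
    lambda n, c, h, w: [n, c // 4, IndexMap.AXIS_SEPARATOR, h, w, c % 4]
)
# Omitting axis_separators (or passing []) keeps the buffer flat, which
# preserves the op's previous behavior.
y = relax.op.layout_transform(x, index_map, axis_separators=axis_separators)
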
8 changes: 7 additions & 1 deletion python/tvm/relax/transform/transform.py
@@ -899,6 +899,7 @@ def DecomposeOpsForTraining(func_name: Optional[str] = None) -> tvm.ir.transform
def AlterOpImpl(
op_impl_map: Dict[str, PrimFunc],
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]],
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]],
):
"""Replace all PrimFunc's which have matching 'operator_name' attribute, with replacement
PrimFunc that could possibly have different layouts on i/o buffers. The layout
@@ -912,6 +913,9 @@
op_kind to PrimFunc map
op_buffer_transforms: Dict[str, List[Union[IndexMap, Callable]]
op_kind to layout transformation map for each of the buffers
op_buffer_axis_separators: Dict[str, List[Union[IndexMap.AXIS_SEPARATOR, Callable]]]
op_kind to axis_separators for each index_map

Returns
-------
ret: tvm.ir.transform.Pass
@@ -924,7 +928,9 @@
l.append(transform)
op_buffer_transforms[operator_name] = l

return _ffi_api.AlterOpImpl(op_impl_map, op_buffer_transforms) # type: ignore
return _ffi_api.AlterOpImpl(
op_impl_map, op_buffer_transforms, op_buffer_axis_separators
) # type: ignore


def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pass:
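
A sketch of how the three maps line up when invoking the pass. Entries in op_buffer_transforms and op_buffer_axis_separators are positional, one per i/o buffer of the replacement PrimFunc; replacement_conv2d is a hypothetical PrimFunc (and mod an existing IRModule) assumed to be defined elsewhere:

from tvm import relax
from tvm.tir import IndexMap

# Two inputs and one output: one transform and one separator list per buffer.
packed, separators = IndexMap.from_func_with_separators(
    lambda n, c, h, w: [n, c // 4, IndexMap.AXIS_SEPARATOR, h, w, c % 4]
)
identity = lambda n, c, h, w: [n, c, h, w]  # callables are converted to IndexMaps

mod = relax.transform.AlterOpImpl(
    op_impl_map={"relax.conv2d": replacement_conv2d},
    op_buffer_transforms={"relax.conv2d": [packed, identity, identity]},
    op_buffer_axis_separators={"relax.conv2d": [separators, [], []]},
)(mod)

Empty separator lists keep the corresponding buffers flat, mirroring the default in layout_transform.
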
4 changes: 3 additions & 1 deletion src/relax/op/tensor/manipulate.cc
@@ -425,10 +425,12 @@ TVM_REGISTER_OP("relax.flatten")
/* relax.layout_transform */
TVM_REGISTER_NODE_TYPE(LayoutTransformAttrs);

Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value) {
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
attrs->index_map = std::move(index_map);
attrs->pad_value = std::move(pad_value);
attrs->axis_separators = std::move(axis_separators);

static const Op& op = Op::Get("relax.layout_transform");
return Call(op, {std::move(x)}, Attrs{attrs}, {});
5 changes: 4 additions & 1 deletion src/relax/op/tensor/manipulate.h
@@ -65,9 +65,12 @@ Expr flatten(Expr x);
* \param index_map The transformation to apply.
* \param pad_value The value used for padding if the transformation results in implicit padding. If
* not specified, any value can be used.
* \param axis_separators Array of values to differentiate between input axes
* when generating flattened output axes.
* \return The transformed result.
*/
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value);
Expr layout_transform(Expr x, tir::IndexMap index_map, Optional<PrimValue> pad_value,
Optional<Array<IntImm>> axis_separators);

/*!
* \brief Permutes the dimensions of an array.
55 changes: 41 additions & 14 deletions src/relax/transform/alter_op_impl.cc
@@ -76,11 +76,13 @@ bool IsTransformBijective(const Expr& expr, const IndexMap& transform) {
class AlterOpImplMutator : public ExprMutator {
public:
AlterOpImplMutator(const IRModule& mod, const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_)
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_)
: ExprMutator(mod),
mod_(mod),
op_impl_map_(op_impl_map),
op_buffer_transforms__(op_buffer_transforms_) {}
op_buffer_transforms__(op_buffer_transforms_),
op_buffer_axis_separators__(axis_separators_) {}

IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
@@ -119,7 +121,10 @@
const auto& replacement_func = op_impl_map_[op_kind];

Array<IndexMap> buffer_transforms;
Optional<Array<Array<IntImm>>> axis_separators;
if (op_buffer_transforms__.count(op_kind)) buffer_transforms = op_buffer_transforms__[op_kind];
if (op_buffer_axis_separators__.count(op_kind))
axis_separators = op_buffer_axis_separators__[op_kind];

ICHECK(buffer_transforms.empty() || buffer_transforms.size() == replacement_func->params.size())
<< "Either the i/o buffers do not require any transformations or transformations for each "
@@ -130,15 +135,15 @@
GlobalVar replacement_gv = GetOrCreateGlobalVarForFunc(replacement_func, op_kind);

auto call_tir_inputs_tuple = GetRef<Tuple>(call->args[1].as<TupleNode>());
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms);
Tuple updated_inputs = UpdateInputs(call_tir_inputs_tuple, buffer_transforms, axis_separators);

ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir sinfo_args.size() is expected to be 1";
StructInfo updated_ret_sinfo = UpdateStructInfo(call->sinfo_args[0], buffer_transforms);
auto updated_call = builder_->Normalize(
Call(call_tir_op_, {replacement_gv, updated_inputs}, call->attrs, {updated_ret_sinfo}));

// Now transform each of the outputs to previous layout.
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0]);
return TransformOutputs(updated_call, buffer_transforms, call->sinfo_args[0], axis_separators);
}

Array<TensorStructInfo> GetTensorStructInfoPerOutput(const StructInfo& output_sinfo) {
@@ -157,17 +162,20 @@
return arr_tensor_sinfo;
}

Expr TransformLayout(const Expr& expr, const IndexMap& index_map) {
Expr TransformLayout(const Expr& expr, const IndexMap& index_map,
const Array<IntImm> axis_separators) {
ObjectPtr<LayoutTransformAttrs> attrs = make_object<LayoutTransformAttrs>();
// We want to avoid two layout_transform ops to share the same index map even if they are
// identical. The scope of vars used in index map initial indices is local to the op. Not doing
// so would confuse the structural equality check.
attrs->index_map = std::move(DeepCopyIndexMap(index_map));
attrs->axis_separators = std::move(axis_separators);
return Call(layout_transform_op_, {expr}, Attrs{std::move(attrs)}, {});
}

Expr TransformLayoutInverse(const Expr& expr, const IndexMap& index_map,
const TensorStructInfo& old_tensor_sinfo) {
const TensorStructInfo& old_tensor_sinfo,
const Array<IntImm>& axis_separator) {
Array<PrimExpr> old_shape = GetShapeFromTensorStructInfo(old_tensor_sinfo);
Array<Range> initial_ranges = ConstructRangeFromShape(old_shape);
arith::Analyzer analyzer;
@@ -177,7 +185,7 @@
<< "Only bijective transformations on input/output buffers are supported, but found "
"padding predicate "
<< padding_predicate << " on initial range " << initial_ranges;
return TransformLayout(expr, inverse_index_map);
return TransformLayout(expr, inverse_index_map, axis_separator);
}

@@ -202,16 +210,22 @@
/*!
* \brief Updates call inputs with layout transformed inputs
*/
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms) {
Tuple UpdateInputs(const Tuple& inputs, const Array<IndexMap>& transforms,
const Optional<Array<Array<IntImm>>>& axis_separators) {
if (transforms.empty()) return inputs;

Array<Expr> updated_inputs;
int index = 0;
for (const auto& input : inputs->fields) {
Array<IntImm> axis_separator;
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_separator = axis_separators_value[index];
}
auto transform = transforms[index++];
ICHECK(IsTransformBijective(input, transform))
<< "Non bijective transforms on input and output buffers are not supported.";
updated_inputs.push_back(TransformLayout(input, transform));
updated_inputs.push_back(TransformLayout(input, transform, axis_separator));
}
return Tuple(updated_inputs);
}
@@ -254,29 +268,39 @@
}

Expr TransformOutputs(const Expr& expr, const Array<IndexMap>& buffer_transforms,
const StructInfo& old_struct_info) {
const StructInfo& old_struct_info,
const Optional<Array<Array<IntImm>>>& axis_separators) {
if (buffer_transforms.empty()) return expr;

Array<TensorStructInfo> old_output_sinfo = GetTensorStructInfoPerOutput(old_struct_info);

Array<IntImm> axis_sep;
size_t num_outputs = old_output_sinfo.size();
if (num_outputs == 0) return expr;

size_t first_output_index = buffer_transforms.size() - num_outputs;
// If there is a single output, return the transformed output.
if (num_outputs == 1) {
IndexMap output_map = buffer_transforms[first_output_index];
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0]);
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[first_output_index];
}
return TransformLayoutInverse(expr, output_map, old_output_sinfo[0], axis_sep);
}

// In case of more than one output, we would have to get each item of the output tuple,
// transform it and return a tuple of all transformed outputs.
Array<Expr> transformed_outputs;
for (size_t i = 0; i + first_output_index < buffer_transforms.size(); ++i) {
const auto& output_map = buffer_transforms[i + first_output_index];
if (axis_separators.defined()) {
Array<Array<IntImm>> axis_separators_value = axis_separators.value();
axis_sep = axis_separators_value[i + first_output_index];
}
auto output = builder_->Normalize(TupleGetItem(expr, static_cast<int>(i)));
transformed_outputs.push_back(
TransformLayoutInverse(output, output_map, old_output_sinfo[i]));
TransformLayoutInverse(output, output_map, old_output_sinfo[i], axis_sep));
}
return Tuple(transformed_outputs);
}
@@ -290,6 +314,8 @@
const Map<String, PrimFunc>& op_impl_map_;
/*! \brief Map from kOperatorName attribute to the layout transforms on i/o buffers */
const Map<String, Array<IndexMap>>& op_buffer_transforms__;
/*! \brief Map from kOperatorName attribute to the axis separators on i/o buffers */
const Map<String, Array<Array<IntImm>>>& op_buffer_axis_separators__;

const Op& call_tir_op_ = Op::Get("relax.call_tir");
const Op& layout_transform_op_ = Op::Get("relax.layout_transform");
@@ -298,10 +324,11 @@
namespace transform {

Pass AlterOpImpl(const Map<String, tir::PrimFunc>& op_impl_map,
const Map<String, Array<IndexMap>>& op_buffer_transforms_) {
const Map<String, Array<IndexMap>>& op_buffer_transforms_,
const Map<String, Array<Array<IntImm>>>& axis_separators_) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
PassContext pc) {
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_).Run();
return AlterOpImplMutator(mod, op_impl_map, op_buffer_transforms_, axis_separators_).Run();
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
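
Net effect of the mutator, sketched as before/after pseudo-IR (names, shapes, and maps are illustrative, not actual pass output; packed and identity are hypothetical index maps as in the sketches above): each input is wrapped in a relax.layout_transform carrying its per-buffer axis_separators, the call is redirected to the replacement PrimFunc, and each output is inverse-transformed back to its original layout, with the output buffer's separators reused on the inverse path as in TransformLayoutInverse above.

# Before:
gv = R.call_tir(conv2d, (x, w), out_sinfo=R.Tensor((8, 64, 32, 32), "float32"))

# After (sketch):
x1 = R.layout_transform(x, index_map=packed, axis_separators=[2])
w1 = R.layout_transform(w, index_map=identity, axis_separators=[])
y1 = R.call_tir(replacement_conv2d, (x1, w1), out_sinfo=...)
gv = R.layout_transform(y1, index_map=packed_inverse, axis_separators=[2])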