[Draft][Unity] Allow dynamic indices to TupleGetItem #16002

Closed

Conversation

Lunderberg
Contributor

This PR updates the type of TupleGetItem::index from int to Expr, allowing a tuple to be accessed at a location specified by a symbolic variable. The lack of this functionality came up multiple times while implementing pre-sharded model weights (MLC-LLM PR link). While that case could be worked around by using R.strided_slice instead of R.split, the workaround is not as clean a solution.

This PR is currently marked as a draft: while it passes the unit tests that it adds, it is likely to cause breakage elsewhere and may require backwards-compatibility updates.
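For illustration, here is a minimal C++ sketch of what the widened index would permit. It is a hypothetical usage sketch based on this draft: the Expr-typed TupleGetItem constructor does not exist before this PR, and the PrimValue wrapping of the TIR variable is an assumption.

// Hypothetical sketch under the proposed change: TupleGetItem::index becomes
// a relax::Expr, so a symbolic (runtime-known) index can be used.
#include <tvm/relax/expr.h>
#include <tvm/tir/var.h>

using namespace tvm;
using namespace tvm::relax;

Expr GetShard(Expr tuple_of_shards, tir::Var shard_index) {
  // Wrap the TIR variable as a relax PrimValue so it can act as the index.
  Expr index = PrimValue(shard_index);
  // With this PR, the TupleGetItem constructor accepts an Expr index
  // (assumption based on the diff below).
  return TupleGetItem(tuple_of_shards, index);
}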

@@ -338,7 +338,7 @@ class TupleGetItemNode : public ExprNode {
/*! \brief The tuple Expression */
Expr tuple;
/*! \brief which value to get */
int index;
Expr index;
Member

For the symbolic fetching case, we should use a lower-level intrinsic. This is mainly to keep constant access of tuple elements (which is the common case) simple.

@tqchen
Member

tqchen commented Oct 30, 2023

Given how differently passes and deduction must handle a symbolic index versus a constant index, in this particular case it is better to restrict tuple element access to constant indices.

We can introduce another op, similar to strided slice, to handle the dynamic case.

@Lunderberg
Contributor Author

This was brought about by a use case where I was unable to express a contiguous slice of a tensor, and had to express a non-strided slice using the strided slice operator. The intent of "the n-th of 4 equal slices" is expressed much more clearly as R.slice(tensor, 4, axis=0)[slice_index] than as R.strided_slice(tensor, axes=[0], begin=[tensor.struct_info.shape[0]*slice_index // 4], end=[tensor.struct_info.shape[0]*(slice_index+1) // 4], assume_inbound=True). While I was able to work around the Relax limitation by using the latter form, the simpler expression should be writeable.

@tqchen
Member

tqchen commented Oct 30, 2023

We can create sugar functions in the front end that de-sugar to the related expression. The comment is mainly about ensuring the AST stays the same.

@Lunderberg
Contributor Author

I don't think I'm as concerned about the C++ types used in the AST, especially for a newer IR like Relax, as I am about the expressibility of user intent. Because R.slice(tensor, 4, axis=0)[slice_index] is a clear expression of user intent, it retains correct shape inference even if a later change mutates the tensor argument to have a different shape. On the other hand, even if the R.strided_slice form is produced through a front-end de-sugaring, it would be ambiguous whether the indices were intended as "split into 4 equal parts" or as "split at these specific indices". Resolving that ambiguity is necessary for correct shape inference.

@tqchen
Member

tqchen commented Oct 31, 2023

To be clear, I think it is OK to support that particular syntax and dispatch to an intrinsic operator. https://github.com/apache/tvm/blob/unity/src/runtime/relax_vm/builtin.cc#L470.

If the particular intent is common, perhaps we would like to think about how to represent that compound operator. My understanding is that this is needed for slicing up weights, but is not necessarily important for optimization, so if existing ops work, that would be a better approach.

The main consideration is to keep TupleGetItem simple, since most passes handle constant subscripting, and we want to contain that complexity. There is a tradeoff between keeping the overall AST general and making it easy for developers to write passes that access and operate on a normal form which can still capture the same computation.

The original rationale of the tuple and its getitem counterpart is for us to be able to access compile-time structures and deduce them in a lossless fashion. This is the common case in most scenarios.

Symbolic indexing can be useful, but it is the less common case.

Tying the AST of the two cases together would inevitably bring extra mental overhead to pass writers. Many (perhaps myself included) could no longer write most of the nested tuple propagation as easily or correctly, and would need to think about constant specialization.

Having a separate intrinsic would help alleviate this issue, since we know that the other case requires less optimization, and the passes can focus on the common case.

return index == 0 ? nlayouts.LeafValue() : LayoutDecision("");
auto int_index = Downcast<PrimStructInfo>(index->struct_info_)
->value.as<IntImm>()
.value_or(Integer(0))
Member

@tqchen tqchen Oct 31, 2023

Implementations like this try to generalize, but leave more uncertainty in the code's behavior, e.g. what if we have a Tuple[Tuple[Tensor, Tensor], Tensor]? This is one of the reasons why keeping the data structure contained is much more desirable.

It is better to "do less" but ensure that what we can handle is robust, since over-generalizing can cause unexpected errors that are harder to manage.

For the layout propagation pass, the right scope is to handle constant tuple indices well while leaving the other cases opaque, or perhaps to put more effort into pattern-matching some special cases where the tuple/array has a homogeneous type.

Contributor Author

Implementations like this try to generalize, but leave more uncertainty in the code's behavior, e.g. what if we have a Tuple[Tuple[Tensor, Tensor], Tensor]?

Here, I maintained the existing behavior, where the default indicates that we don't have a known decision and cannot copy a previous layout decision. For an unknown index, the same default is used.

For the layout propagation pass, the right scope is to handle constant tuple indices well while leaving the other cases opaque.

I'm not opposed to this handling either. My aim for this PR was to extend the IR and to handle the cases that could clearly be extended. Where a case wasn't clearly extendable, I marked it as a current limitation of the pass, supporting only cases where the exact value of the index is known.

Since there are questions about this implementation, it's evidently not a clear generalization, and I'll update it to restrict usage to known indices.

Member

The comment is not about requesting changes; it is mainly to illustrate that there is a gap and an added burden if we go with the generalized TupleGetItem route. So the suggestion is to introduce a separate operator for such cases.

Contributor Author

Good point. I've tended to see it as an added burden that must exist somewhere in the library. The burden can either be placed at the producer's side, at the consumer's side, or as part of a utility, but can't be removed entirely without also removing expressibility. Exactly where that burden is best placed depends on the specific case.

I especially like the separate operator in this case, because the default behavior for a Call node (if a tuple is passed as an argument, assume all elements of the tuple are touched) is also the correct behavior for access of a tuple at an unknown location (assume any element of the tuple may be touched). Everything else can be handled by existing utility functions without imposing a burden at either the producer or consumer.

@tqchen
Member

tqchen commented Nov 1, 2023

How about we enable an op relax.tuple_getitem_dyn that allows dynamic indices? That would let us separate the two cases while still clearly representing the intent.

@Lunderberg
Contributor Author

How about we enable an op relax.tuple_getitem_dyn that allows dynamic indices? That would let us separate the two cases while still clearly representing the intent.

I think that makes sense. Sketching out the steps that would be required, while it's fresh in my mind:

  • A C++ utility function Expr tuple_get_item(Expr tuple, index) (a sketch of this dispatch follows the list)

    • The index argument is overloaded to accept int index, PrimExpr index, or relax::Expr index.
    • Type checking: the dtype of a PrimExpr index or relax::Expr index must be int64.
    • A statically-known index (int index, PrimExpr index with IntImm, or relax::Expr index with known IntImm value in PrimStructInfo) returns a TupleGetItem node from tuple_get_item.
    • Anything else returns a relax.tuple_getitem_dyn from tuple_get_item.
  • Update the relax.Expr.__getitem__ function to call tuple_get_item

  • (QoL) Implement operator[] for Expr to delegate to the tuple_get_item function.

  • Update the TVMScript printer to check for relax.tuple_getitem_dyn and print it as my_tuple[index].

  • Add an operator attribute FNormalize. This would be a TypedPackedFunc<Expr(const BlockBuilder&, Expr)>, and would be called during Expr Normalizer::VisitExpr_(const CallNode*). This would be a general-purpose utility that allows operators to specify the normalization that they require.

  • Implement FNormalize for relax.tuple_getitem_dyn. In case of a statically-known index, it would be normalized to a TupleGetItem node. Doing this as part of normalization would allow pattern-matching passes to continue checking for TupleGetItem explicitly.
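To make the intended dispatch concrete, here is a minimal C++ sketch of the tuple_get_item utility described above. The op name relax.tuple_getitem_dyn and the struct-info inspection follow the discussion in this thread and are assumptions, not existing API; the sketch also assumes the op has already been registered.

#include <tvm/ir/op.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/struct_info.h>

using namespace tvm;
using namespace tvm::relax;

Expr tuple_get_item(Expr tuple, Expr index) {
  // If the index carries a statically-known IntImm value in its struct info,
  // emit the existing constant-index TupleGetItem node.
  if (const auto* prim_sinfo = index->struct_info_.as<PrimStructInfoNode>()) {
    if (const auto* int_imm = prim_sinfo->value.as<IntImmNode>()) {
      return TupleGetItem(tuple, static_cast<int>(int_imm->value));
    }
  }
  // Otherwise, fall back to the proposed dynamic-index builtin op.
  static const Op& dyn_op = Op::Get("relax.tuple_getitem_dyn");
  return Call(dyn_op, {tuple, index});
}

An FNormalize implementation for the op would perform the reverse mapping, rewriting a call whose index becomes statically known back into a TupleGetItem node, so that pattern-matching passes can continue to look for TupleGetItem.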

@tqchen
Member

tqchen commented Nov 2, 2023

Looks good. FNormalize does not seem to be a hard dependency, since in most cases the code is already normalized (per the parser overloading). It is indeed useful to include this in the canonicalize pass.

@Lunderberg
Contributor Author

Looks good. FNormalize does not seem to be a hard dependency, since in most cases the code is already normalized (per the parser overloading). It is indeed useful to include this in the canonicalize pass.

Agreed, not a hard dependency. The functionality is more to allow these normalization requirements to exist in a centralized location, without requiring individual passes to be aware of them. It's functionality that I could also see being useful for other operators that have AST requirements beyond those expressed in the C++ type system.

@Lunderberg Lunderberg force-pushed the unity_tuple_get_item_at_expr branch 2 times, most recently from 1b3b5bf to dd02245 on November 6, 2023 at 19:34
@Lunderberg
Contributor Author

I've re-implemented this change using a builtin tuple_get_item_dyn, and I'm really liking how well it integrates with FNormalize. The normalization ensures that we use TupleGetItem whenever possible, to expose static indices, while still providing user-friendly generation and printing of the tuple_get_item_dyn operator.

In addition to #16067, the test cases also rely on #15983 to handle propagation of primitives within a tuple owned by the RelaxVM.

@Lunderberg Lunderberg force-pushed the unity_tuple_get_item_at_expr branch 2 times, most recently from fd61801 to 1fb2578 on November 7, 2023 at 14:19
@Lunderberg
Contributor Author

Rebased onto unity to include #16067. Will still require #15983 in order to pass unit tests, so it remains a draft for now.

TVM containers, such as tvm::runtime::Array, require the contained
objects to inherit from `ObjectRef`.  As a result, the wrapper types
`IntImm`, `FloatImm`, and `StringImm` are often used to hold native
types in the TVM containers.  Conversions into these wrapper types may
be required when using a container, and may be performed automatically
when passing an object across the FFI.  By also providing conversion
to an unwrapped type, these automatic conversions become transparent
to users.

The trait can be specialized to add type-specific conversion logic
from TVMArgValue and TVMRetValue.
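As a small illustration of the wrapping described above, using only existing container and IR types:

#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

using namespace tvm;

void WrapperExample() {
  // Array elements must be ObjectRef subclasses, so raw integers are stored
  // through the IntImm wrapper type rather than as plain int64_t.
  Array<IntImm> values = {IntImm(DataType::Int(64), 3),
                          IntImm(DataType::Int(64), 7)};
  // Retrieving an element yields the wrapper; the native value is one field
  // access away.
  int64_t first = values[0]->value;
  (void)first;
}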
Because `isinstance(bool_value, int)` returns True, boolean values
were being converted to `T.int64`, instead of to `T.bool`.
Prior to this commit, the `Array::Map` member function could only be
applied to nullable object types.  This was due to the internal use of
`U()` as the default value for initializing the output `ArrayNode`, where
`U` is the return type of the mapping function.  This default
constructor is only available for nullable types, and would result in
a compile-time failure for non-nullable types.

This commit replaces `U()` with `ObjectRef()` in `Array::Map`,
removing this limitation.  Since all items in the output array are
overwritten before returning to the calling scope, initializing the
output array with `ObjectRef()` does not violate type safety.
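A brief sketch of the member function in question; Array::Map is existing API, and the non-nullable-return behavior is as described in the commit message above:

#include <tvm/runtime/container/array.h>
#include <tvm/tir/op.h>

using namespace tvm;

Array<PrimExpr> AddOne(const Array<PrimExpr>& values) {
  // Array::Map applies the callback to every element and returns a new array.
  // Per the commit message, the output slots are now initialized with
  // ObjectRef() and overwritten before the result is returned, so the
  // callback's return type no longer needs a default constructor.
  return values.Map([](PrimExpr v) -> PrimExpr { return v + 1; });
}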
Prior to this commit, `int`, `float`, and `bool` arguments from Python
were converted to `IntImm`, `FloatImm`, and `Bool`.  These are
subtypes of `PrimExpr`, and should only be used at compile time.  Because
this conversion is applied automatically as part of the FFI, these types
must be present whenever a primitive is converted to a
`tvm::ObjectRef`.

This can become especially fragile for an end-user when storing
objects into a TVM container.  Because TVM containers require all
contents to be `ObjectRef` subclasses, an automatic conversion may be
applied on storing into a container, resulting in an unexpected type
being retrieved from the container.  For example, this currently
occurs in Relax when extracting a `R.Prim` from a `R.Tuple`.
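A small sketch of the kind of surprise described above, using a PackedFunc registered from C++; the global name "example.describe" is hypothetical:

#include <tvm/runtime/container/string.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using namespace tvm::runtime;

// Reports the runtime type of whatever object the FFI delivered.  Prior to
// this commit, calling describe(3) from Python reports the compile-time IR
// type "IntImm", because the int was converted on its way into ObjectRef.
TVM_REGISTER_GLOBAL("example.describe").set_body_typed([](ObjectRef obj) {
  return String(obj.defined() ? obj->GetTypeKey() : "nullptr");
});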

This commit introduces a `Box<T>` type for storage of boxed primitives
at runtime, distinct from the IR types.

* Primitive arguments provided to a PackedFunc that requires an
  `ObjectRef` will be converted to the corresponding boxed type.
  (e.g. Passing a Python `int` to a C++ function accepting `ObjectRef`
  produces a `Box<int64_t>`.)

* Boxed primitives provided to a PackedFunc that requires an unboxed
  primitive will be converted to the corresponding primitive.

* PackedFunc return values of `ObjectRef` are converted to the
  corresponding primitive, if present.  (e.g. If a `tuple_getitem`
  with static return type `ObjectRef` returns a `Box<int64_t>`, it
  will be unwrapped to a python `int`.)

Together, these three rules provide backwards compatibility for
existing PackedFunc definitions, while avoiding exposing the user to
any container-induced type conversions between primitive types and
`ObjectRef`.
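A minimal sketch of the first and third rules above from the C++ side; the global names are hypothetical, and Box<int64_t> is the boxed type introduced by this commit:

#include <tvm/runtime/container/array.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

using namespace tvm::runtime;

// Rule 1: a Python int passed where ObjectRef is expected arrives boxed
// (per this commit, as a Box<int64_t>), so it can be stored in a container.
TVM_REGISTER_GLOBAL("example.append")
    .set_body_typed([](Array<ObjectRef> arr, ObjectRef item) {
      arr.push_back(item);
      return arr;
    });

// Rule 3: a boxed primitive returned through an ObjectRef-typed result is
// unwrapped back to a Python int on the caller's side.
TVM_REGISTER_GLOBAL("example.first").set_body_typed([](Array<ObjectRef> arr) {
  return arr[0];
});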
Mostly, this requires removing `.value` unwrapping that is now applied
automatically.
* Change the tir.Call signature to accept `Array<Variant<...>>` instead of
  `Array<ObjectRef>`.  This allows the FFI to apply registered
  conversions.

* Update target parsing to expect the default object types.

* Extend conversion into PrimExpr.  Several APIs that check for a
  PrimExpr may now receive a `runtime.String`, `runtime.Box<bool>` or
  `runtime.Box<int64_t>`.  These must be converted to `StringImm`,
  `Bool`, or `IntImm` for use in functions that accept `PrimExpr`.
@Lunderberg
Contributor Author

The PR has been re-implemented to use functionality from #16183, instead of the now closed PR #15983. The new #16183 provides the same functionality required by this PR, allowing a relax tuple's runtime representation as Array<ObjectRef> to be transparently unwrapped to an integer when accessed.
