[IR] Implemented Variant<...> container #15672

Merged
merged 3 commits into from
Sep 13, 2023

Conversation

Lunderberg

This commit introduces a new container, Variant, which is analogous to the std::variant introduced in C++17, the enum in Rust, or a tagged union in C. The Variant class is templated over the types that it may contain (e.g. Variant<String, Expr>), where each type is a distinct option that can be stored within the container.

Variant is implemented as a subclass of ObjectRef with no additional data members, similar to the implementation of Optional<T>. It can be constructed from any of its contained types, and the contents can be inspected using the usual my_object.as<T>() and Downcast<T>(my_object) methods. This is intended to allow for drop-in replacement of ObjectRef with Variant<Type1, Type2, ...> in places that previously used a common base class.
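As a rough, standalone illustration of this design (not the TVM implementation), the sketch below wraps a single pointer with no extra tag member and inspects the contents through an `as<T>()` method. The names `Node`, `StringNode`, `IntNode`, `MiniVariant`, and `Example` are hypothetical stand-ins, and `std::shared_ptr` with `dynamic_cast` is assumed here in place of TVM's `ObjectPtr` and its run-time type index.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

// Hypothetical stand-ins for an object hierarchy; these are not TVM types.
struct Node {
  virtual ~Node() = default;
};
struct StringNode : Node {
  explicit StringNode(std::string v) : value(std::move(v)) {}
  std::string value;
};
struct IntNode : Node {
  explicit IntNode(int v) : value(v) {}
  int value;
};

// A reference-like wrapper that stores one pointer and no separate tag.
template <typename... Ts>
class MiniVariant {
 public:
  MiniVariant() = default;

  // Implicit construction from any of the allowed types.
  template <typename T, typename = std::enable_if_t<(std::is_same_v<T, Ts> || ...)>>
  MiniVariant(std::shared_ptr<T> ptr) : data_(std::move(ptr)) {}

  // Construction from an untyped pointer, validated at run time.
  explicit MiniVariant(std::shared_ptr<Node> ptr) : data_(std::move(ptr)) {
    bool ok = data_ == nullptr || ((dynamic_cast<Ts*>(data_.get()) != nullptr) || ...);
    if (!ok) throw std::invalid_argument("type not allowed in this variant");
  }

  // Returns a pointer to the contents if they are a T, otherwise nullptr.
  template <typename T>
  const T* as() const {
    return dynamic_cast<const T*>(data_.get());
  }

 private:
  std::shared_ptr<Node> data_;
};

// Usage: the active alternative follows whatever was last assigned.
inline void Example() {
  MiniVariant<StringNode, IntNode> v = std::make_shared<IntNode>(1);
  assert(v.as<IntNode>() != nullptr);
  v = std::make_shared<StringNode>("hello");
  assert(v.as<IntNode>() == nullptr);
}
```

Because the wrapper holds nothing beyond the pointer, it has the same layout as a plain reference to the base type, which is what makes a drop-in replacement plausible in this sketch.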

To ensure that each variant can be uniquely retrieved, no type stored within the variant may inherit from any other type within the variant. This condition is checked at compile-time, with a static_assert explaining the limitation. This condition is necessary to mimic the semantics of std::variant, whose active member depends on the compile-time type of an object. Without this condition, the expression Variant<PrimExpr, tir::Var> variant = PrimExpr(...) could populate either of the variants depending on the run-time type of an object. Because the Variant class is primarily intended for use when two types do not already inherit from each other, this limitation is not expected to limit its utility.
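One way such a compile-time check could be written is sketched below using `std::is_base_of_v` and fold expressions. This is an assumption about the general technique, not a copy of the check in this PR; `kNumBasesIn`, `kAllUnrelated`, `CheckedVariant`, and the toy `Expr`/`Var`/`Str` classes are hypothetical names, and the sketch assumes every listed type is a class type.

```cpp
#include <type_traits>

// Counts how many types in Ts are a base of T. For class types,
// std::is_base_of_v<T, T> is true, so T always counts itself once.
template <typename T, typename... Ts>
inline constexpr int kNumBasesIn = (0 + ... + (std::is_base_of_v<Ts, T> ? 1 : 0));

// True when every type is related only to itself, so no listed type
// inherits from another (duplicate entries are rejected as well).
template <typename... Ts>
inline constexpr bool kAllUnrelated = ((kNumBasesIn<Ts, Ts...> == 1) && ...);

// Hypothetical container that enforces the restriction when instantiated.
template <typename... Ts>
struct CheckedVariant {
  static_assert(kAllUnrelated<Ts...>,
                "No type in the variant may inherit from another type in the variant");
};

// Toy hierarchy for illustration only; these are not the TVM classes.
struct Expr {};
struct Var : Expr {};
struct Str {};

// Unrelated types are accepted.
CheckedVariant<Expr, Str> accepted;
// CheckedVariant<Expr, Var> rejected;  // would fail the static_assert
```

With this restriction in place, the alternative selected by an assignment is determined by the static type alone, which is the property the paragraph above describes.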

There are several locations within the TVM codebase where this pattern may be useful, and which are currently worked around using various strategies. (This PR does not alter any existing implementations, instead introducing the Variant container that can be used in subsequent PRs, if desired.)

  • Workaround: Store a common base class. For example, the type of relax::TensorStructInfoNode::shape is Optional<Expr>, with a comment stating that it should be only NullOpt, ShapeExpr, or Var. However, these restrictions are not checked by the compiler, and a developer could erroneously provide a different type. By expressing the type as Optional<Variant<Var,ShapeExpr>>, these errors could be automatically caught.

  • Workaround: Use additional data structures. For example, a PrimFunc parameter may be either a TIR primitive, which is lowered to a primitive type, or a TIR Buffer, which is lowered to a DLTensor* argument and appropriate unpacking code. However, these two types are represented as an Array<tir::Var> and a Map<tir::Var, tir::Buffer>, which together represent an Array<Variant<tir::Var, tir::Buffer>>. The separate data structures must be kept in sync whenever modified, such as when removing a parameter.

  • Workaround: Use std::variant. For example, the tvm::tir::IdentifyMemCpyImpl utility function returns a std::variant with the result or an error message. However, this is only suitable for use within a C++ implementation, and requires a wrapper in order to expose it to the FFI.
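To make the second workaround more concrete, the sketch below keeps parameters in a single ordered list whose elements record their own kind, so removing a parameter touches only one container. It uses `std::variant` and `std::vector` purely as stand-ins; `Var`, `Buffer`, `Param`, `ParamName`, `RemoveParam`, `CountBuffers`, and `Example` are hypothetical names and not the TVM types or API.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for tir::Var and tir::Buffer.
struct Var {
  std::string name;
};
struct Buffer {
  std::string name;
};

// Each parameter carries its own kind, so no side table is needed.
using Param = std::variant<Var, Buffer>;

// Returns the name of a parameter, whichever alternative it holds.
inline std::string ParamName(const Param& param) {
  return std::visit([](const auto& p) { return p.name; }, param);
}

// Removes a parameter by name from the single container.
inline void RemoveParam(std::vector<Param>& params, const std::string& name) {
  params.erase(std::remove_if(params.begin(), params.end(),
                              [&](const Param& p) { return ParamName(p) == name; }),
               params.end());
}

// Counts how many parameters are buffers.
inline std::size_t CountBuffers(const std::vector<Param>& params) {
  return static_cast<std::size_t>(
      std::count_if(params.begin(), params.end(),
                    [](const Param& p) { return std::holds_alternative<Buffer>(p); }));
}

// Usage: one erase keeps the parameter list consistent.
inline void Example() {
  std::vector<Param> params{Var{"n"}, Buffer{"A"}};
  RemoveParam(params, "A");
  assert(params.size() == 1);
  assert(CountBuffers(params) == 0);
}
```

With two parallel structures, the same removal would need matching updates to both the list and the map.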

@csullivan csullivan left a comment

Looks great!

Comment on lines +858 to +876
TEST(Variant, Construct) {
  Variant<PrimExpr, String> variant;
  variant = PrimExpr(1);
  ICHECK(variant.as<PrimExpr>());
  ICHECK(!variant.as<String>());

  variant = String("hello");
  ICHECK(variant.as<String>());
  ICHECK(!variant.as<PrimExpr>());
}

TEST(Variant, InvalidTypeThrowsError) {
  auto expected_to_throw = []() {
    ObjectPtr<Object> node = make_object<Object>();
    Variant<PrimExpr, String> variant(node);
  };

  EXPECT_THROW(expected_to_throw(), InternalError);
}

A rather small set of tests, albeit for a fairly small API surface as compared to Array and Map. Are there other tests we could add? Maybe check assignment?

TEST(Variant, Assignment) {
  Variant<PrimExpr, String> variant;
  Variant<PrimExpr, String> variant2 = String("foo");
  variant = PrimExpr(1);
  variant2 = variant;

  ICHECK(variant2.as<PrimExpr>());
  // check the value of variant2
}

Good point. I made the API surface as small as possible, but there were additional tests that should be included. I've added tests to validate that reference equality is preserved across Variant assignments, and that the values are correctly preserved.

Thanks @Lunderberg!

@csullivan

cc @junrushao

@csullivan csullivan merged commit 24847c5 into apache:main Sep 13, 2023
18 checks passed
@Lunderberg Lunderberg deleted the variant_container branch September 14, 2023 22:01
Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Jun 14, 2024
Prior to the implementation of `Variant<...>` in
apache#15672, functions that were
polymorphic over an argument type would typically accept an
`ObjectRef` argument, then downcast to an allowed type.  This delays
the catching of an error, and can accidentally omit automatic
conversions applied by the FFI.

This commit updates several locations using this pattern to instead
accept a `Variant`, templated over the allowed types.  This enables
C++ type checking for C++ callers, standardizes the type-checking in
the FFI for non-C++ callers, and ensures that FFI type conversions are
uniformly applied.