-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[IR] Implemented Variant<...> container #15672
Conversation
This commit introduces a new container, `Variant`, which is analogous to the `std::variant` introduced in C++17, the `enum` in Rust, or a tagged union in C. The `Variant` class is templated over the types that it may contain (e.g. `Variant<String, Expr>`), where each type is a distinct option that can be stored within the container. `Variant` is implemented as a subclass of `ObjectRef` with no additional data members, similar to the implementation of `Optional<T>`. It can be constructed from any of its contained types, and the contents can be inspected using the usual `my_object.as<T>()` and `Downcast<T>(my_object)` methods. This is intended to allow for drop-in replacement of `ObjectRef` with `Variant<Type1, Type2, ...>` in places that previously used a common base class. To ensure that each variant can be uniquely retrieved, no type stored within the variant may inherit from any other type within the variant. This condition is checked at compile-time, with a `static_assert` explaining the limitation. This condition is necessary to mimic the semantics of `std::variant`, whose active member depends on the compile-time type of an object. Without this condition, the expression `Variant<PrimExpr, tir::Var> variant = PrimExpr(...)` could populate either of the variants depending on the run-time type of an object. Because the `Variant` class is primarily intended for use when two types do not already inherit from each other, this limitation is not expected to limit its utility. There are several locations within the TVM codebase where this pattern may be useful, and which are currently worked around various strategies. (This PR does not alter any existing implementations, instead introducing the `Variant` container that can be used in subsequent PRs, if desired.) * Workaround: Store a common base class. For example, the type of `relax::TensorStructInfoNode::shape` is `Optional<Expr>`, with a comment stating that it should be only `NullOpt`, `ShapeExpr`, or `Var`. However, these restrictions are not checked by the compiler, and a developer could erroneously provide a different type. By expressing the type as as `Optional<Variant<Var,ShapeExpr>>`, these errors could be automatically caught. * Workaround: Use additional data structures. For example, a `PrimFunc` parameter may be either a TIR primitive, which is lowered to a primitive type, or a TIR Buffer, which is lowered to a `DLTensor*` argument and appropriate unpacking code. However, these two types are represented as an `Array<tir::Var>` and a `Map<tir::Var, tir::Buffer>`, which together represent a `Array<Variant<tir::Var, tir::Buffer>>`. The separate data structures must be kept in sync whenever modified, such as when removing a parameter. * Workaround: Use `std::variant`. For example, the `tvm::tir::IdentifyMemCpyImpl` utility function returns a `std::variant` with the result or an error message. However, this is only suitable for use within a C++ implementation, and requires a wrapper in order to expose it to the FFI.
387eadc
to
422eafd
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
TEST(Variant, Construct) { | ||
Variant<PrimExpr, String> variant; | ||
variant = PrimExpr(1); | ||
ICHECK(variant.as<PrimExpr>()); | ||
ICHECK(!variant.as<String>()); | ||
|
||
variant = String("hello"); | ||
ICHECK(variant.as<String>()); | ||
ICHECK(!variant.as<PrimExpr>()); | ||
} | ||
|
||
TEST(Variant, InvalidTypeThrowsError) { | ||
auto expected_to_throw = []() { | ||
ObjectPtr<Object> node = make_object<Object>(); | ||
Variant<PrimExpr, String> variant(node); | ||
}; | ||
|
||
EXPECT_THROW(expected_to_throw(), InternalError); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A rather small set of tests, albeit for a fairly small API surface as compared to Array and Map. Are there other tests we could add? Maybe check assignment?
TEST(Variant, Assignment) {
Variant<PrimExpr, String> variant;
Variant<PrimExpr, String> variant2 = String("foo");
variant = PrimExpr(1);
variant2 = variant;
ICHECK(variant2.as<PrimExpr>());
# check the value of variant2
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point. I made the API surface as small as possible, but there were additional tests that should be included. I've added tests to validate that reference equality is preserved across Variant
assignments, and that the values are correctly preserved.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Lunderberg!
cc @junrushao |
TEST(Variant, Construct) { | ||
Variant<PrimExpr, String> variant; | ||
variant = PrimExpr(1); | ||
ICHECK(variant.as<PrimExpr>()); | ||
ICHECK(!variant.as<String>()); | ||
|
||
variant = String("hello"); | ||
ICHECK(variant.as<String>()); | ||
ICHECK(!variant.as<PrimExpr>()); | ||
} | ||
|
||
TEST(Variant, InvalidTypeThrowsError) { | ||
auto expected_to_throw = []() { | ||
ObjectPtr<Object> node = make_object<Object>(); | ||
Variant<PrimExpr, String> variant(node); | ||
}; | ||
|
||
EXPECT_THROW(expected_to_throw(), InternalError); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @Lunderberg!
Prior to the implementation of `Variant<...>` in apache#15672, functions that were polymorphic over an argument type would typically accept an `ObjectRef` argument, then downcast to an allowed type. This delays the catching of an error, and can accidentally omit automatic conversions applied by the FFI. This commit updates several locations using this pattern to instead accept a `Variant`, templated over the allowed types. This enables C++ type checking for C++ callers, standardizes the type-checking in the FFI for non-C++ callers, and ensures that FFI type conversions are uniformly applied.
Prior to the implementation of `Variant<...>` in apache#15672, functions that were polymorphic over an argument type would typically accept an `ObjectRef` argument, then downcast to an allowed type. This delays the catching of an error, and can accidentally omit automatic conversions applied by the FFI. This commit updates several locations using this pattern to instead accept a `Variant`, templated over the allowed types. This enables C++ type checking for C++ callers, standardizes the type-checking in the FFI for non-C++ callers, and ensures that FFI type conversions are uniformly applied.
This commit introduces a new container,
Variant
, which is analogous to thestd::variant
introduced in C++17, theenum
in Rust, or a tagged union in C. TheVariant
class is templated over the types that it may contain (e.g.Variant<String, Expr>
), where each type is a distinct option that can be stored within the container.Variant
is implemented as a subclass ofObjectRef
with no additional data members, similar to the implementation ofOptional<T>
. It can be constructed from any of its contained types, and the contents can be inspected using the usualmy_object.as<T>()
andDowncast<T>(my_object)
methods. This is intended to allow for drop-in replacement ofObjectRef
withVariant<Type1, Type2, ...>
in places that previously used a common base class.To ensure that each variant can be uniquely retrieved, no type stored within the variant may inherit from any other type within the variant. This condition is checked at compile-time, with a
static_assert
explaining the limitation. This condition is necessary to mimic the semantics ofstd::variant
, whose active member depends on the compile-time type of an object. Without this condition, the expressionVariant<PrimExpr, tir::Var> variant = PrimExpr(...)
could populate either of the variants depending on the run-time type of an object. Because theVariant
class is primarily intended for use when two types do not already inherit from each other, this limitation is not expected to limit its utility.There are several locations within the TVM codebase where this pattern may be useful, and which are currently worked around various strategies. (This PR does not alter any existing implementations, instead introducing the
Variant
container that can be used in subsequent PRs, if desired.)Workaround: Store a common base class. For example, the type of
relax::TensorStructInfoNode::shape
isOptional<Expr>
, with a comment stating that it should be onlyNullOpt
,ShapeExpr
, orVar
. However, these restrictions are not checked by the compiler, and a developer could erroneously provide a different type. By expressing the type as asOptional<Variant<Var,ShapeExpr>>
, these errors could be automatically caught.Workaround: Use additional data structures. For example, a
PrimFunc
parameter may be either a TIR primitive, which is lowered to a primitive type, or a TIR Buffer, which is lowered to aDLTensor*
argument and appropriate unpacking code. However, these two types are represented as anArray<tir::Var>
and aMap<tir::Var, tir::Buffer>
, which together represent aArray<Variant<tir::Var, tir::Buffer>>
. The separate data structures must be kept in sync whenever modified, such as when removing a parameter.Workaround: Use
std::variant
. For example, thetvm::tir::IdentifyMemCpyImpl
utility function returns astd::variant
with the result or an error message. However, this is only suitable for use within a C++ implementation, and requires a wrapper in order to expose it to the FFI.