Skip to content

Commit

Permalink
[Utility][Container] Support non-nullable types in Array::Map (#17094)
Browse files Browse the repository at this point in the history
[Container] Support non-nullable types in Array::Map

Prior to this commit, the `Array::Map` member function could only be
applied to nullable object types.  This was due to the internal use of
`U()` as the default value for initializing the output `ArrayNode`, where
`U` is the return type of the mapping function.  This default
constructor is only available for nullable types, and would result in
a compile-time failure for non-nullable types.

This commit replaces `U()` with `ObjectRef()` in `Array::Map`,
removing this limitation.  Since all items in the output array are
overwritten before returning to the calling scope, initializing the
output array with `ObjectRef()` does not violate type safety.
  • Loading branch information
Lunderberg committed Jun 18, 2024
1 parent a4f20f0 commit e520b9b
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,13 @@ class Array : public ObjectRef {
// consisting of any previous elements that had mapped to
// themselves (if any), and the element that didn't map to
// itself.
//
// We cannot use `U()` as the default object, as `U` may be
// a non-nullable type. Since the default `ObjectRef()`
// will be overwritten before returning, all objects will be
// of type `U` for the calling scope.
all_identical = false;
output = ArrayNode::CreateRepeated(arr->size(), U());
output = ArrayNode::CreateRepeated(arr->size(), ObjectRef());
output->InitRange(0, arr->begin(), it);
output->SetItem(it - arr->begin(), std::move(mapped));
it++;
Expand All @@ -843,7 +848,12 @@ class Array : public ObjectRef {
// compatible types isn't strictly necessary, as the first
// mapped.same_as(*it) would return false, but we might as well
// avoid it altogether.
output = ArrayNode::CreateRepeated(arr->size(), U());
//
// We cannot use `U()` as the default object, as `U` may be a
// non-nullable type. Since the default `ObjectRef()` will be
// overwritten before returning, all objects will be of type `U`
// for the calling scope.
output = ArrayNode::CreateRepeated(arr->size(), ObjectRef());
}

// Normal path for incompatible types, or post-copy path for
Expand Down

0 comments on commit e520b9b

Please sign in to comment.