Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] tir.transform.StorageFlatten refactor #9091

Merged
merged 21 commits into from
Oct 1, 2021

Conversation

Lunderberg
Copy link
Contributor

This started after noticing that StorageFlatten incorrectly handled BufferLoad/BufferStore nodes that pointed to a buffer defined in an attr::buffer_bind_scope annotation. Rather than adding more logic into the existing StorageFlattener mutator, I split up the existing behavior into multiple independent mutators.

This PR includes a series of commits, each of which refactors one of the behaviors out of the StorageFlattener class and into a separate class. While all of the transforms are called sequentially in the tir.transform.StorageFlatten to maintain the same overall behavior, each transform results in a valid TIR tree.

  • BufferShapeLegalize, which rewrites Buffer nodes to have sizes that match the BufferRealize node in which they are defined.
  • BufferStrideLegalize, which rewrites the strides of Buffer nodes that are annotated with attr::dim_align.
  • ThreadScopePropagate, which defines the allocation scope of Buffer nodes based on the thread iter in which they are declared, if no allocation scope was already defined.
  • BufferBindUnwrapper, which rewrites access into Buffer objects that are defined by attr::buffer_bind_scope. Refactoring this behavior into a separate mutator was my original goal, in order to resolve the issue of BufferLoad/BufferStore nodes that point to bound buffers, but doing so required the previous three behaviors to also be refactored into separate mutators.
  • StorageFlattener, which contains all remaining behavior from the original StorageFlattener, and outputs the final Allocate/Store/Load nodes.

This refactor will also help in the future, when introducing layout transformations.

@Lunderberg
Copy link
Contributor Author

@csullivan

@tmoreau89
Copy link
Contributor

CC @kparzysz-quic

@junrushao
Copy link
Member

CC: @vinx13

Copy link
Contributor

@csullivan csullivan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks great @Lunderberg, thank you! I'm part way through review and will continue more this afternoon. Very nice to see the functionality of storage flatten broken out into their logical units, and also dependence on AttrStmts removed early.

Comment on lines 74 to 94
ICHECK_EQ(op->buffer->shape.size(), op->bounds.size())
<< "External buffer realize has mismatched dimension";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider making this a CHECK as external buffers can be provided by the user and failing this check could indicate a use error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, and changed.

}

private:
Stmt HandleBufferBindScope(const AttrStmtNode* op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a comment describing that the buffer bind scope buffer attributes are updated according to the legalized buffer. Similarly add comments in the other passes which have BufferBindScope handler methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, and comments have been added.


// Keeping this to have matched behavior to previous version.
// There are many parts of the codebase that assume that a strided
// array cannot be compact.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: An example of one such case that makes this assumption could be useful for context.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, and added.

Copy link
Contributor

@csullivan csullivan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Finishing up the review,

// but this binded region is a subregion of
// a matrix(tensor), which means it requires strides.
//
// We do support a few relaxed case, such as bindingx
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo,

Suggested change
// We do support a few relaxed case, such as bindingx
// We do support a few relaxed case, such as binding a

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo fixed.

Comment on lines 946 to 1011
PrimExpr ElemOffset() const {
ICHECK(remap);

Buffer copy = remap->target;
{
Array<PrimExpr> shape;
for (auto r : bounds) {
shape.push_back(r->extent);
}
copy.CopyOnWrite()->shape = std::move(shape);
}

Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents);
if (buffer->strides.size() == 0) {
ICHECK_EQ(target_slice->strides.size(), 0U)
<< "Trying to bind compact buffer to strided one strides=" << target_slice->strides;
} else {
target_slice = target_slice.MakeStrideView();
}

return copy->ElemOffset(remap->begins);
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not seeing struct BufferEntry::ElemOffset used anywhere. Consider removing or refactoring to use the method?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And removed. It was part of an earlier (and broken) implementation.

Comment on lines 1301 to 1406
// The specific tensor data layout is not determined before
// StorageFlatten pass. We use buffer_bind_scope
// to specify before hand we want to bind a subregion
// of tensor to a symbolic buffer, which get used in extern.
//
// Example:
//
// realize A in range [i*4, extent=10) {
// bind Ab to A in [i*4+1, extent=4) {
// call_func(Ab.ptr, Ab.shape[0])
// }
// }
//
// After StorageFlatten
//
// alloc A[10]
// call(A + 1, 4)
//
// Buffer is a protocol to declare specific
// data layout and shape we expect.
// So this function need to check:
// - If the bind range is within the realize range
// - If we can match the requirement of buffer
// - Remap variables such as Ab.ptr to the actual value.
//
// Here are a few possible failure cases:
// - Buffer is declared to have constant shape,
// but we try to bind it to a different one.
// - Buffer is declared to be compact(no strides)
// but this binded region is a subregion of
// a matrix(tensor), which means it requires strides.
//
// We do support a few relaxed case, such as bindingx
// region with shape [1, 1, n, m] to buffer with shape [n, m]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These docs are duplicated from those on the BufferBindUnwrapper. Consider removing or updating.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you, and updated to have shorter documentation in BufferBindUnwrapper, while StorageFlatten maintains the full documentation.


fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body));

fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In reading through the refactor, it occurs to me that the passes prior to BufferBindUnwrapper could be simpler if the buffer_bind_scope was unwrapped earlier, e.g. prior to ThreadScopePropagate or perhaps first of all. Then each pass would not need to special case the handling done in the variants of HandleBufferBindScope.

Is there something that makes it difficult to apply BufferBindUnwrapper earlier?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that I'd want to unwrap the binds earlier, if I could, to prevent the amount of rewriting needed to pass the updated buffers along. The main issue I ran into was for IR definitions that directly reference buf.elem_offset (example). In order to determine the offset of the bufffer view relative to the data pointer of the parent buffer, the shape and strides of the parent buffer need to be determined first.

I have two ideas for making the implementation be cleaner and more readable. One is to change how data are packed in an AttrStmtNode for buffer_bind_scope, to use the MatchBufferRegion class. The other is to extend StmtExprMutator to act on BufferNode, so that the buffer replacements only need to be done in a single location for each pass through. As it is, rewriting the BufferStoreNode, BufferLoadNode, AttrStmtNode, and CallNode must be done each time the buffer gets modified, even if it's just to use the modified Buffer object.

Lunderberg added a commit to Lunderberg/tvm that referenced this pull request Sep 28, 2021
Parametrized it to get more detailed information while debugging
failures in apache#9091, but isn't
semantically part of that PR.
// behavior from ArgBinder::BindBuffer.
size_t diff = entry.realized_begins.size() - op->indices.size();
for (size_t i = 0; i < diff; i++) {
new_indices.push_back(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming this is for matching cases like [1, 1, n, m], do we need to check that the leading axes are indeed extent=1?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me ask a more general question, is it possible to expand or squeeze all shapes and then do exact matching as before? Noticing a fair amount of special casing in this commit for handling the extra unit dimensions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the specific case, there is a check as part of the call to ArgBinder::BindBuffer. I'll add a comment to indicate that.

For the general question, that was the intent in BufferShapeLegalize, so that afterwards the buffers all have a single well-defined shape. It looks like I had missed one case where BufferBindUnwrapper changed the number of dimensions when binding to an external buffer, but I've added another commit to this PR to pull that behavior into BufferShapeLegalize instead.

masahi pushed a commit that referenced this pull request Sep 29, 2021
Parametrized it to get more detailed information while debugging
failures in #9091, but isn't
semantically part of that PR.
Allowed a compact DLTensor to bind to a Buffer object that defines
strides, if the strides defined correspond to a compact layout.
Previously, StorageFlattener would determine the shape of a physical
buffer based on the extents of the BufferRealizeNode.  Pulled these
out into a separate BufferShapeLegalize pass.  After this pass, all
buffers have a shape that matches the buffer realization extents.
Previously, StorageFlattener would handle any attr::dim_align
annotations.  Now, this is pulled out into a separate
BufferStrideLegalize pass.
Previously, StorageFlattener would use the scope in IterVar to assign
a scope to allocated buffers, where not otherwise defined.  This has
been pulled out into a separate ThreadScopePropagate pass.
Previously, StorageFlattener would look for `attr::buffer_bind_scope`
to determine if a Buffer object is a view into another buffer, and
would apply that mapping while making the Allocate/Store/Load nodes.
Now, the mapping of buffer binds is pulled out into a separate
BufferStrideUnwrapper pass.

This also resolves an issue in which BufferLoad/BufferStore nodes that
refer to a Buffer defined through `attr::buffer_bind_scope` would
generate Load/Store nodes that point to the linked buffer, rather than
the actual buffer.
Even after BufferShapeLegalize, rank-zero tensors may have an empty
shape.
Original refactoring requiring that a bufferview have no explicit
striding, and instead take the striding from the buffer that it is
viewing.  Modified to allow bufferview to specify striding, so long as
it is consistent with the viewed buffer's striding.  This reproduces
the behavior of StorageFlatten before the refactoring.
AttrStmtNodes that contain rewritten Buffers need to be rewritten as
well.
The earlier stage of the refactor left a buffer's storage scope
undefined if it's scope was not determined by the IterVar of a loop
containing its allocation.  Now, these are explicitly set to
StorageScope::kGlobal, to match the previous behavior of
StorageFlatten.
Maintains earlier behavior of StorageFlatten, which allows buffer
views to be mapped to higher dimension buffers, if the view extent is
1 in each extra dimension.
Previously, BufferBindUnwrapper passed fuzzy_match=true to
ArgBinder::BindBuffer, which could change the number of dimensions.
Now, all buffer dimensions should be updated prior to
BufferBindUnwrapper, and it is an error to have mismatched dimensions
in BufferBindUnwrapper.
@Lunderberg
Copy link
Contributor Author

Latest round of changes, added another pass to remove assert statements that can be statically validated. These are placed by ArgBinder::Bind if it can't verify a constraint at the time when it binds a variable. If later variable substitutions allow the constraint to be statically verified, they can still remain in the final generated code. These didn't appear prior to the refactor, because StorageFlatten made a single substitution, whereas the refactor does so in multiple passes.

ArgBinder::BindBuffer inserts these assert statements if they are not
verifiable at the time of substitution.  Previously, with one giant
substitution, the assertions were verifiable at that time.  After the
refactor, with substitutions done in multiple stages for
shape/stride/buffer_bind_scope, we need to clean up any assertions
that are verifiable after all substitutions have occurred.
- Removed StorageFlattener::BufferEntry::RelIndex, behavior already
  handled by BufferShapeLegalize.

- Improved comments and error messages.

- Extracted duplicate behavior in BufferLoad/BufferStore handling in
  BufferShapeLegalize.
A true Assert statement can be removed, but a false Assert statement
requires CFA to give as a compile-time error.  Since we only need the
removal of true assert statements, skipping the CFA this time.
Copy link
Contributor

@csullivan csullivan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks @Lunderberg!

Copy link
Contributor

@tmoreau89 tmoreau89 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @Lunderberg LGTM

@tmoreau89 tmoreau89 merged commit 659f3b7 into apache:main Oct 1, 2021
@tmoreau89
Copy link
Contributor

Thanks @csullivan and @Lunderberg the PR has been merged.

@Lunderberg Lunderberg deleted the storage_flatten_refactor branch October 1, 2021 01:00
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
Parametrized it to get more detailed information while debugging
failures in apache#9091, but isn't
semantically part of that PR.
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
* [TE] Improved flexibility of ArgBinder::BindDLTensor

Allowed a compact DLTensor to bind to a Buffer object that defines
strides, if the strides defined correspond to a compact layout.

* [TIR] Exposed ElemOffset as a member function of BufferNode.

* [TE] Pulled shape determination out of StorageFlattener

Previously, StorageFlattener would determine the shape of a physical
buffer based on the extents of the BufferRealizeNode.  Pulled these
out into a separate BufferShapeLegalize pass.  After this pass, all
buffers have a shape that matches the buffer realization extents.

* [TE] Refactor stride calculation out of StorageFlattener

Previously, StorageFlattener would handle any attr::dim_align
annotations.  Now, this is pulled out into a separate
BufferStrideLegalize pass.

* [TE] Refactor thread scope propagation out of StorageFlattener.

Previously, StorageFlattener would use the scope in IterVar to assign
a scope to allocated buffers, where not otherwise defined.  This has
been pulled out into a separate ThreadScopePropagate pass.

* [TE] Refactor buffer bind mapping out of StorageFlattener.

Previously, StorageFlattener would look for `attr::buffer_bind_scope`
to determine if a Buffer object is a view into another buffer, and
would apply that mapping while making the Allocate/Store/Load nodes.
Now, the mapping of buffer binds is pulled out into a separate
BufferStrideUnwrapper pass.

This also resolves an issue in which BufferLoad/BufferStore nodes that
refer to a Buffer defined through `attr::buffer_bind_scope` would
generate Load/Store nodes that point to the linked buffer, rather than
the actual buffer.

* [TIR] Removed checks on buffer->shape.size()

Even after BufferShapeLegalize, rank-zero tensors may have an empty
shape.

* [TIR] Relaxed check on a bufferview's striding.

Original refactoring requiring that a bufferview have no explicit
striding, and instead take the striding from the buffer that it is
viewing.  Modified to allow bufferview to specify striding, so long as
it is consistent with the viewed buffer's striding.  This reproduces
the behavior of StorageFlatten before the refactoring.

* [TIR] Fixed StorageFlatten test for shape_legalize.

AttrStmtNodes that contain rewritten Buffers need to be rewritten as
well.

* [TIR] Assigned storage scope

The earlier stage of the refactor left a buffer's storage scope
undefined if it's scope was not determined by the IterVar of a loop
containing its allocation.  Now, these are explicitly set to
StorageScope::kGlobal, to match the previous behavior of
StorageFlatten.

* Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided
data.

* Added comments in storage_flatten.cc, indicating why buffer_bind_scope
needs special handling.

* Updated comment with a few examples of where compact buffers are
assumed to have no strides defined.

* Updated following @csullivan's comments.

* Added fuzzy mapping to the BufferShapeLegalize.

Maintains earlier behavior of StorageFlatten, which allows buffer
views to be mapped to higher dimension buffers, if the view extent is
1 in each extra dimension.

* Updated BufferShapeLegalize, asserts need to be inside the buffer_bind_scope.

* Pulled all shape-dependent behavior into BufferShapeLegalize.

Previously, BufferBindUnwrapper passed fuzzy_match=true to
ArgBinder::BindBuffer, which could change the number of dimensions.
Now, all buffer dimensions should be updated prior to
BufferBindUnwrapper, and it is an error to have mismatched dimensions
in BufferBindUnwrapper.

* Added another pass to remove verifiable assert statements.

ArgBinder::BindBuffer inserts these assert statements if they are not
verifiable at the time of substitution.  Previously, with one giant
substitution, the assertions were verifiable at that time.  After the
refactor, with substitutions done in multiple stages for
shape/stride/buffer_bind_scope, we need to clean up any assertions
that are verifiable after all substitutions have occurred.

* Minor cleanup

- Removed StorageFlattener::BufferEntry::RelIndex, behavior already
  handled by BufferShapeLegalize.

- Improved comments and error messages.

- Extracted duplicate behavior in BufferLoad/BufferStore handling in
  BufferShapeLegalize.

* Updated to handle BufferRealizeNode with no defined bounds.

* Updated to be less aggressive when checking AssertStmt

A true Assert statement can be removed, but a false Assert statement
requires CFA to give as a compile-time error.  Since we only need the
removal of true assert statements, skipping the CFA this time.
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
Parametrized it to get more detailed information while debugging
failures in apache#9091, but isn't
semantically part of that PR.
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
* [TE] Improved flexibility of ArgBinder::BindDLTensor

Allowed a compact DLTensor to bind to a Buffer object that defines
strides, if the strides defined correspond to a compact layout.

* [TIR] Exposed ElemOffset as a member function of BufferNode.

* [TE] Pulled shape determination out of StorageFlattener

Previously, StorageFlattener would determine the shape of a physical
buffer based on the extents of the BufferRealizeNode.  Pulled these
out into a separate BufferShapeLegalize pass.  After this pass, all
buffers have a shape that matches the buffer realization extents.

* [TE] Refactor stride calculation out of StorageFlattener

Previously, StorageFlattener would handle any attr::dim_align
annotations.  Now, this is pulled out into a separate
BufferStrideLegalize pass.

* [TE] Refactor thread scope propagation out of StorageFlattener.

Previously, StorageFlattener would use the scope in IterVar to assign
a scope to allocated buffers, where not otherwise defined.  This has
been pulled out into a separate ThreadScopePropagate pass.

* [TE] Refactor buffer bind mapping out of StorageFlattener.

Previously, StorageFlattener would look for `attr::buffer_bind_scope`
to determine if a Buffer object is a view into another buffer, and
would apply that mapping while making the Allocate/Store/Load nodes.
Now, the mapping of buffer binds is pulled out into a separate
BufferStrideUnwrapper pass.

This also resolves an issue in which BufferLoad/BufferStore nodes that
refer to a Buffer defined through `attr::buffer_bind_scope` would
generate Load/Store nodes that point to the linked buffer, rather than
the actual buffer.

* [TIR] Removed checks on buffer->shape.size()

Even after BufferShapeLegalize, rank-zero tensors may have an empty
shape.

* [TIR] Relaxed check on a bufferview's striding.

Original refactoring requiring that a bufferview have no explicit
striding, and instead take the striding from the buffer that it is
viewing.  Modified to allow bufferview to specify striding, so long as
it is consistent with the viewed buffer's striding.  This reproduces
the behavior of StorageFlatten before the refactoring.

* [TIR] Fixed StorageFlatten test for shape_legalize.

AttrStmtNodes that contain rewritten Buffers need to be rewritten as
well.

* [TIR] Assigned storage scope

The earlier stage of the refactor left a buffer's storage scope
undefined if it's scope was not determined by the IterVar of a loop
containing its allocation.  Now, these are explicitly set to
StorageScope::kGlobal, to match the previous behavior of
StorageFlatten.

* Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided
data.

* Added comments in storage_flatten.cc, indicating why buffer_bind_scope
needs special handling.

* Updated comment with a few examples of where compact buffers are
assumed to have no strides defined.

* Updated following @csullivan's comments.

* Added fuzzy mapping to the BufferShapeLegalize.

Maintains earlier behavior of StorageFlatten, which allows buffer
views to be mapped to higher dimension buffers, if the view extent is
1 in each extra dimension.

* Updated BufferShapeLegalize, asserts need to be inside the buffer_bind_scope.

* Pulled all shape-dependent behavior into BufferShapeLegalize.

Previously, BufferBindUnwrapper passed fuzzy_match=true to
ArgBinder::BindBuffer, which could change the number of dimensions.
Now, all buffer dimensions should be updated prior to
BufferBindUnwrapper, and it is an error to have mismatched dimensions
in BufferBindUnwrapper.

* Added another pass to remove verifiable assert statements.

ArgBinder::BindBuffer inserts these assert statements if they are not
verifiable at the time of substitution.  Previously, with one giant
substitution, the assertions were verifiable at that time.  After the
refactor, with substitutions done in multiple stages for
shape/stride/buffer_bind_scope, we need to clean up any assertions
that are verifiable after all substitutions have occurred.

* Minor cleanup

- Removed StorageFlattener::BufferEntry::RelIndex, behavior already
  handled by BufferShapeLegalize.

- Improved comments and error messages.

- Extracted duplicate behavior in BufferLoad/BufferStore handling in
  BufferShapeLegalize.

* Updated to handle BufferRealizeNode with no defined bounds.

* Updated to be less aggressive when checking AssertStmt

A true Assert statement can be removed, but a false Assert statement
requires CFA to give as a compile-time error.  Since we only need the
removal of true assert statements, skipping the CFA this time.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants