-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR] tir.transform.StorageFlatten refactor #9091
Conversation
CC: @vinx13 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks great @Lunderberg, thank you! I'm part way through review and will continue more this afternoon. Very nice to see the functionality of storage flatten broken out into their logical units, and also dependence on AttrStmts removed early.
ICHECK_EQ(op->buffer->shape.size(), op->bounds.size()) | ||
<< "External buffer realize has mismatched dimension"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider making this a CHECK as external buffers can be provided by the user and failing this check could indicate a use error.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, and changed.
} | ||
|
||
private: | ||
Stmt HandleBufferBindScope(const AttrStmtNode* op) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment describing that the buffer bind scope buffer attributes are updated according to the legalized buffer. Similarly add comments in the other passes which have BufferBindScope handler methods.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, and comments have been added.
|
||
// Keeping this to have matched behavior to previous version. | ||
// There are many parts of the codebase that assume that a strided | ||
// array cannot be compact. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: An example of one such case that makes this assumption could be useful for context.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Makes sense, and added.
d797be3
to
a92100d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Finishing up the review,
// but this binded region is a subregion of | ||
// a matrix(tensor), which means it requires strides. | ||
// | ||
// We do support a few relaxed case, such as bindingx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo,
// We do support a few relaxed case, such as bindingx | |
// We do support a few relaxed case, such as binding a |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo fixed.
PrimExpr ElemOffset() const { | ||
ICHECK(remap); | ||
|
||
Buffer copy = remap->target; | ||
{ | ||
Array<PrimExpr> shape; | ||
for (auto r : bounds) { | ||
shape.push_back(r->extent); | ||
} | ||
copy.CopyOnWrite()->shape = std::move(shape); | ||
} | ||
|
||
Buffer target_slice = copy.MakeSlice(remap->begins, remap->extents); | ||
if (buffer->strides.size() == 0) { | ||
ICHECK_EQ(target_slice->strides.size(), 0U) | ||
<< "Trying to bind compact buffer to strided one strides=" << target_slice->strides; | ||
} else { | ||
target_slice = target_slice.MakeStrideView(); | ||
} | ||
|
||
return copy->ElemOffset(remap->begins); | ||
} | ||
}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not seeing struct BufferEntry::ElemOffset
used anywhere. Consider removing or refactoring to use the method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And removed. It was part of an earlier (and broken) implementation.
// The specific tensor data layout is not determined before | ||
// StorageFlatten pass. We use buffer_bind_scope | ||
// to specify before hand we want to bind a subregion | ||
// of tensor to a symbolic buffer, which get used in extern. | ||
// | ||
// Example: | ||
// | ||
// realize A in range [i*4, extent=10) { | ||
// bind Ab to A in [i*4+1, extent=4) { | ||
// call_func(Ab.ptr, Ab.shape[0]) | ||
// } | ||
// } | ||
// | ||
// After StorageFlatten | ||
// | ||
// alloc A[10] | ||
// call(A + 1, 4) | ||
// | ||
// Buffer is a protocol to declare specific | ||
// data layout and shape we expect. | ||
// So this function need to check: | ||
// - If the bind range is within the realize range | ||
// - If we can match the requirement of buffer | ||
// - Remap variables such as Ab.ptr to the actual value. | ||
// | ||
// Here are a few possible failure cases: | ||
// - Buffer is declared to have constant shape, | ||
// but we try to bind it to a different one. | ||
// - Buffer is declared to be compact(no strides) | ||
// but this binded region is a subregion of | ||
// a matrix(tensor), which means it requires strides. | ||
// | ||
// We do support a few relaxed case, such as bindingx | ||
// region with shape [1, 1, n, m] to buffer with shape [n, m] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These docs are duplicated from those on the BufferBindUnwrapper. Consider removing or updating.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, and updated to have shorter documentation in BufferBindUnwrapper, while StorageFlatten maintains the full documentation.
|
||
fptr->body = ThreadScopePropagate(fptr->buffer_map)(std::move(fptr->body)); | ||
|
||
fptr->body = BufferBindUnwrapper(fptr->buffer_map, &bound_analyzer)(std::move(fptr->body)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In reading through the refactor, it occurs to me that the passes prior to BufferBindUnwrapper could be simpler if the buffer_bind_scope
was unwrapped earlier, e.g. prior to ThreadScopePropagate or perhaps first of all. Then each pass would not need to special case the handling done in the variants of HandleBufferBindScope
.
Is there something that makes it difficult to apply BufferBindUnwrapper
earlier?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that I'd want to unwrap the binds earlier, if I could, to prevent the amount of rewriting needed to pass the updated buffers along. The main issue I ran into was for IR definitions that directly reference buf.elem_offset
(example). In order to determine the offset of the bufffer view relative to the data
pointer of the parent buffer, the shape and strides of the parent buffer need to be determined first.
I have two ideas for making the implementation be cleaner and more readable. One is to change how data are packed in an AttrStmtNode
for buffer_bind_scope
, to use the MatchBufferRegion
class. The other is to extend StmtExprMutator
to act on BufferNode
, so that the buffer replacements only need to be done in a single location for each pass through. As it is, rewriting the BufferStoreNode
, BufferLoadNode
, AttrStmtNode
, and CallNode
must be done each time the buffer gets modified, even if it's just to use the modified Buffer object.
Parametrized it to get more detailed information while debugging failures in apache#9091, but isn't semantically part of that PR.
// behavior from ArgBinder::BindBuffer. | ||
size_t diff = entry.realized_begins.size() - op->indices.size(); | ||
for (size_t i = 0; i < diff; i++) { | ||
new_indices.push_back(0); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Assuming this is for matching cases like [1, 1, n, m], do we need to check that the leading axes are indeed extent=1?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me ask a more general question, is it possible to expand or squeeze all shapes and then do exact matching as before? Noticing a fair amount of special casing in this commit for handling the extra unit dimensions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For the specific case, there is a check as part of the call to ArgBinder::BindBuffer
. I'll add a comment to indicate that.
For the general question, that was the intent in BufferShapeLegalize
, so that afterwards the buffers all have a single well-defined shape. It looks like I had missed one case where BufferBindUnwrapper
changed the number of dimensions when binding to an external buffer, but I've added another commit to this PR to pull that behavior into BufferShapeLegalize
instead.
Parametrized it to get more detailed information while debugging failures in #9091, but isn't semantically part of that PR.
Allowed a compact DLTensor to bind to a Buffer object that defines strides, if the strides defined correspond to a compact layout.
Previously, StorageFlattener would determine the shape of a physical buffer based on the extents of the BufferRealizeNode. Pulled these out into a separate BufferShapeLegalize pass. After this pass, all buffers have a shape that matches the buffer realization extents.
Previously, StorageFlattener would handle any attr::dim_align annotations. Now, this is pulled out into a separate BufferStrideLegalize pass.
Previously, StorageFlattener would use the scope in IterVar to assign a scope to allocated buffers, where not otherwise defined. This has been pulled out into a separate ThreadScopePropagate pass.
Previously, StorageFlattener would look for `attr::buffer_bind_scope` to determine if a Buffer object is a view into another buffer, and would apply that mapping while making the Allocate/Store/Load nodes. Now, the mapping of buffer binds is pulled out into a separate BufferStrideUnwrapper pass. This also resolves an issue in which BufferLoad/BufferStore nodes that refer to a Buffer defined through `attr::buffer_bind_scope` would generate Load/Store nodes that point to the linked buffer, rather than the actual buffer.
Even after BufferShapeLegalize, rank-zero tensors may have an empty shape.
Original refactoring requiring that a bufferview have no explicit striding, and instead take the striding from the buffer that it is viewing. Modified to allow bufferview to specify striding, so long as it is consistent with the viewed buffer's striding. This reproduces the behavior of StorageFlatten before the refactoring.
AttrStmtNodes that contain rewritten Buffers need to be rewritten as well.
The earlier stage of the refactor left a buffer's storage scope undefined if it's scope was not determined by the IterVar of a loop containing its allocation. Now, these are explicitly set to StorageScope::kGlobal, to match the previous behavior of StorageFlatten.
needs special handling.
assumed to have no strides defined.
Maintains earlier behavior of StorageFlatten, which allows buffer views to be mapped to higher dimension buffers, if the view extent is 1 in each extra dimension.
Previously, BufferBindUnwrapper passed fuzzy_match=true to ArgBinder::BindBuffer, which could change the number of dimensions. Now, all buffer dimensions should be updated prior to BufferBindUnwrapper, and it is an error to have mismatched dimensions in BufferBindUnwrapper.
1e981de
to
41b8eff
Compare
Latest round of changes, added another pass to remove assert statements that can be statically validated. These are placed by ArgBinder::Bind if it can't verify a constraint at the time when it binds a variable. If later variable substitutions allow the constraint to be statically verified, they can still remain in the final generated code. These didn't appear prior to the refactor, because |
ArgBinder::BindBuffer inserts these assert statements if they are not verifiable at the time of substitution. Previously, with one giant substitution, the assertions were verifiable at that time. After the refactor, with substitutions done in multiple stages for shape/stride/buffer_bind_scope, we need to clean up any assertions that are verifiable after all substitutions have occurred.
- Removed StorageFlattener::BufferEntry::RelIndex, behavior already handled by BufferShapeLegalize. - Improved comments and error messages. - Extracted duplicate behavior in BufferLoad/BufferStore handling in BufferShapeLegalize.
41b8eff
to
3d3ec42
Compare
A true Assert statement can be removed, but a false Assert statement requires CFA to give as a compile-time error. Since we only need the removal of true assert statements, skipping the CFA this time.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks @Lunderberg!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @Lunderberg LGTM
Thanks @csullivan and @Lunderberg the PR has been merged. |
Parametrized it to get more detailed information while debugging failures in apache#9091, but isn't semantically part of that PR.
* [TE] Improved flexibility of ArgBinder::BindDLTensor Allowed a compact DLTensor to bind to a Buffer object that defines strides, if the strides defined correspond to a compact layout. * [TIR] Exposed ElemOffset as a member function of BufferNode. * [TE] Pulled shape determination out of StorageFlattener Previously, StorageFlattener would determine the shape of a physical buffer based on the extents of the BufferRealizeNode. Pulled these out into a separate BufferShapeLegalize pass. After this pass, all buffers have a shape that matches the buffer realization extents. * [TE] Refactor stride calculation out of StorageFlattener Previously, StorageFlattener would handle any attr::dim_align annotations. Now, this is pulled out into a separate BufferStrideLegalize pass. * [TE] Refactor thread scope propagation out of StorageFlattener. Previously, StorageFlattener would use the scope in IterVar to assign a scope to allocated buffers, where not otherwise defined. This has been pulled out into a separate ThreadScopePropagate pass. * [TE] Refactor buffer bind mapping out of StorageFlattener. Previously, StorageFlattener would look for `attr::buffer_bind_scope` to determine if a Buffer object is a view into another buffer, and would apply that mapping while making the Allocate/Store/Load nodes. Now, the mapping of buffer binds is pulled out into a separate BufferStrideUnwrapper pass. This also resolves an issue in which BufferLoad/BufferStore nodes that refer to a Buffer defined through `attr::buffer_bind_scope` would generate Load/Store nodes that point to the linked buffer, rather than the actual buffer. * [TIR] Removed checks on buffer->shape.size() Even after BufferShapeLegalize, rank-zero tensors may have an empty shape. * [TIR] Relaxed check on a bufferview's striding. Original refactoring requiring that a bufferview have no explicit striding, and instead take the striding from the buffer that it is viewing. Modified to allow bufferview to specify striding, so long as it is consistent with the viewed buffer's striding. This reproduces the behavior of StorageFlatten before the refactoring. * [TIR] Fixed StorageFlatten test for shape_legalize. AttrStmtNodes that contain rewritten Buffers need to be rewritten as well. * [TIR] Assigned storage scope The earlier stage of the refactor left a buffer's storage scope undefined if it's scope was not determined by the IterVar of a loop containing its allocation. Now, these are explicitly set to StorageScope::kGlobal, to match the previous behavior of StorageFlatten. * Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided data. * Added comments in storage_flatten.cc, indicating why buffer_bind_scope needs special handling. * Updated comment with a few examples of where compact buffers are assumed to have no strides defined. * Updated following @csullivan's comments. * Added fuzzy mapping to the BufferShapeLegalize. Maintains earlier behavior of StorageFlatten, which allows buffer views to be mapped to higher dimension buffers, if the view extent is 1 in each extra dimension. * Updated BufferShapeLegalize, asserts need to be inside the buffer_bind_scope. * Pulled all shape-dependent behavior into BufferShapeLegalize. Previously, BufferBindUnwrapper passed fuzzy_match=true to ArgBinder::BindBuffer, which could change the number of dimensions. Now, all buffer dimensions should be updated prior to BufferBindUnwrapper, and it is an error to have mismatched dimensions in BufferBindUnwrapper. * Added another pass to remove verifiable assert statements. ArgBinder::BindBuffer inserts these assert statements if they are not verifiable at the time of substitution. Previously, with one giant substitution, the assertions were verifiable at that time. After the refactor, with substitutions done in multiple stages for shape/stride/buffer_bind_scope, we need to clean up any assertions that are verifiable after all substitutions have occurred. * Minor cleanup - Removed StorageFlattener::BufferEntry::RelIndex, behavior already handled by BufferShapeLegalize. - Improved comments and error messages. - Extracted duplicate behavior in BufferLoad/BufferStore handling in BufferShapeLegalize. * Updated to handle BufferRealizeNode with no defined bounds. * Updated to be less aggressive when checking AssertStmt A true Assert statement can be removed, but a false Assert statement requires CFA to give as a compile-time error. Since we only need the removal of true assert statements, skipping the CFA this time.
Parametrized it to get more detailed information while debugging failures in apache#9091, but isn't semantically part of that PR.
* [TE] Improved flexibility of ArgBinder::BindDLTensor Allowed a compact DLTensor to bind to a Buffer object that defines strides, if the strides defined correspond to a compact layout. * [TIR] Exposed ElemOffset as a member function of BufferNode. * [TE] Pulled shape determination out of StorageFlattener Previously, StorageFlattener would determine the shape of a physical buffer based on the extents of the BufferRealizeNode. Pulled these out into a separate BufferShapeLegalize pass. After this pass, all buffers have a shape that matches the buffer realization extents. * [TE] Refactor stride calculation out of StorageFlattener Previously, StorageFlattener would handle any attr::dim_align annotations. Now, this is pulled out into a separate BufferStrideLegalize pass. * [TE] Refactor thread scope propagation out of StorageFlattener. Previously, StorageFlattener would use the scope in IterVar to assign a scope to allocated buffers, where not otherwise defined. This has been pulled out into a separate ThreadScopePropagate pass. * [TE] Refactor buffer bind mapping out of StorageFlattener. Previously, StorageFlattener would look for `attr::buffer_bind_scope` to determine if a Buffer object is a view into another buffer, and would apply that mapping while making the Allocate/Store/Load nodes. Now, the mapping of buffer binds is pulled out into a separate BufferStrideUnwrapper pass. This also resolves an issue in which BufferLoad/BufferStore nodes that refer to a Buffer defined through `attr::buffer_bind_scope` would generate Load/Store nodes that point to the linked buffer, rather than the actual buffer. * [TIR] Removed checks on buffer->shape.size() Even after BufferShapeLegalize, rank-zero tensors may have an empty shape. * [TIR] Relaxed check on a bufferview's striding. Original refactoring requiring that a bufferview have no explicit striding, and instead take the striding from the buffer that it is viewing. Modified to allow bufferview to specify striding, so long as it is consistent with the viewed buffer's striding. This reproduces the behavior of StorageFlatten before the refactoring. * [TIR] Fixed StorageFlatten test for shape_legalize. AttrStmtNodes that contain rewritten Buffers need to be rewritten as well. * [TIR] Assigned storage scope The earlier stage of the refactor left a buffer's storage scope undefined if it's scope was not determined by the IterVar of a loop containing its allocation. Now, these are explicitly set to StorageScope::kGlobal, to match the previous behavior of StorageFlatten. * Updated ICHECK_EQ to CHECK_EQ for a test that depends on user-provided data. * Added comments in storage_flatten.cc, indicating why buffer_bind_scope needs special handling. * Updated comment with a few examples of where compact buffers are assumed to have no strides defined. * Updated following @csullivan's comments. * Added fuzzy mapping to the BufferShapeLegalize. Maintains earlier behavior of StorageFlatten, which allows buffer views to be mapped to higher dimension buffers, if the view extent is 1 in each extra dimension. * Updated BufferShapeLegalize, asserts need to be inside the buffer_bind_scope. * Pulled all shape-dependent behavior into BufferShapeLegalize. Previously, BufferBindUnwrapper passed fuzzy_match=true to ArgBinder::BindBuffer, which could change the number of dimensions. Now, all buffer dimensions should be updated prior to BufferBindUnwrapper, and it is an error to have mismatched dimensions in BufferBindUnwrapper. * Added another pass to remove verifiable assert statements. ArgBinder::BindBuffer inserts these assert statements if they are not verifiable at the time of substitution. Previously, with one giant substitution, the assertions were verifiable at that time. After the refactor, with substitutions done in multiple stages for shape/stride/buffer_bind_scope, we need to clean up any assertions that are verifiable after all substitutions have occurred. * Minor cleanup - Removed StorageFlattener::BufferEntry::RelIndex, behavior already handled by BufferShapeLegalize. - Improved comments and error messages. - Extracted duplicate behavior in BufferLoad/BufferStore handling in BufferShapeLegalize. * Updated to handle BufferRealizeNode with no defined bounds. * Updated to be less aggressive when checking AssertStmt A true Assert statement can be removed, but a false Assert statement requires CFA to give as a compile-time error. Since we only need the removal of true assert statements, skipping the CFA this time.
This started after noticing that StorageFlatten incorrectly handled BufferLoad/BufferStore nodes that pointed to a buffer defined in an
attr::buffer_bind_scope
annotation. Rather than adding more logic into the existingStorageFlattener
mutator, I split up the existing behavior into multiple independent mutators.This PR includes a series of commits, each of which refactors one of the behaviors out of the
StorageFlattener
class and into a separate class. While all of the transforms are called sequentially in thetir.transform.StorageFlatten
to maintain the same overall behavior, each transform results in a valid TIR tree.attr::dim_align
.attr::buffer_bind_scope
. Refactoring this behavior into a separate mutator was my original goal, in order to resolve the issue of BufferLoad/BufferStore nodes that point to bound buffers, but doing so required the previous three behaviors to also be refactored into separate mutators.This refactor will also help in the future, when introducing layout transformations.