Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix][TIR]fix symbolic strides lower #15986

Closed
wants to merge 2 commits into from
Closed

[Fix][TIR]fix symbolic strides lower #15986

wants to merge 2 commits into from

Conversation

JackWeiw
Copy link
Contributor

compact_buffer_region PASS modify shared buffer stride[0] to

T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)) and stride[1] is T.int64(72)
but in LowerOpaqueBlock PASS it report error:
InternalError: Check failed: (is_zero(floormod(buffer->strides[i - 1], buffer->strides[i]))) is false:

For more detaied discuss, see here

Another bug occurs in PASS InjectPTXAsyncCopy .
that is dst_offset.dtype could be int64, the dtype of PrimExpr(index_factor) would be set to default to int32.
cause dtype inconsistent when calling tir::Mul.

To reproduce the problem in InjectPTXAsyncCopy, see script here

@JackWeiw
Copy link
Contributor Author

CC @Lunderberg @wrongtest-intellif

def transformed_symbolic_strided_buffer_func(a: T.handle):
n = T.int64()
A = T.match_buffer(a, (1, n, 10240))
for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the test case depend on T.int64 datatypes? If not, this would be much more readable by using T.int32. Because it is the default integer type in TVMScript, it wouldn't require the explicit type conversions. (e.g. (n + 63) instead of (n + T.int64(63)).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank u for the adivice, i've modified it using T.int32 and pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64)) in the test case.

A = T.match_buffer(a, (1, n, 10240))
for i, j, k in T.grid(((n + T.int64(63)) // T.int64(64) * T.int64(4) + T.int64(7)) // T.int64(8), 2, 160):
A_pad_shared_dyn = T.allocate([1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72], "float32", "shared.dyn")
A_pad_shared_dyn_1 = T.decl_buffer((1, T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 64), data=A_pad_shared_dyn, strides=(T.int64(72) * T.min((n + T.int64(63)) // T.int64(64) * T.int64(64), T.int64(96)), 72, 1), scope="shared.dyn")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expression T.min((n + T.int64(63)) // T.int64(64) * T.int64(64) occurs frequently, and makes it difficult to read. Can this be pulled out into padded_size = T.meta_var(T.min((n + T.int64(63)) // T.int64(64) * T.int64(64))? The generated TIR will still contain the full expression, but the test case can be easier to read.

A_pad_shared_dyn = T.alloc_buffer((1, T.min((n + 63) // 64 * 64, 96), 64), "float32", strides=(72 * T.min((n + 63) // 64 * 64, 96), 72, 1), scope="shared.dyn")
for ax0, ax1 in T.grid(96, 64):
with T.block("A_pad_shared.dyn"):
T.where(i * 128 + j * 32 + ax0 < (n + 63) // 64 * 64)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like the same T.reads and T.writes annotations as would be automatically inferred from the block's body. Unless the test depends on a specific override to use non-default read/write annotations, it should be removed for readability.

def compacted_symbolic_strided_buffer_func(a: T.handle) -> None:
n = T.int64()
A = T.match_buffer(a, (1, n, 10240), "float32")
for i, j, k in T.grid(((n + 63) // 64 * 4 + 7) // 8, 2, 160):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated, the presence of this expression is kind of odd to me. Assuming this example came from a TIR printout, I would have expected ((n + 63) // 64 * 4 + 7) // 8 to be simplified to the equivalent (n + 127) // 128. The fact that it didn't simplify may indicate that I should take a look at the CanonicalSimplifier.

@@ -79,7 +80,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
Array<PrimExpr> args = {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
Array<PrimExpr> args = {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since your description mentions this as a separate bug, can it either be split out into a separate PR, or (since it is a relatively small change), have a test case added for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I open a new PR here . Please have a check.
I will open a new PR to fix dtype mismatch bug in PASS InjectPTXAsyncCopy after symbolic strides PR is merged

@JackWeiw JackWeiw closed this by deleting the head repository Oct 27, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants