
[Fix][TIR]fix mul dtype mismatch #16010

Merged
merged 3 commits into apache:main on Oct 31, 2023

Conversation

JackWeiw
Contributor

Another bug occurs in the InjectPTXAsyncCopy pass: dst_offset.dtype can be int64, while the dtype of the PrimExpr index_factor defaults to int32. This causes a dtype mismatch when calling tir::Mul.

To reproduce the problem in InjectPTXAsyncCopy, see the script here
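The fix amounts to constructing the multiplier with the destination offset's dtype instead of the int32 default. A minimal pure-Python sketch of the failure and the fix (the `PrimExpr` class and `mul` helper below are hypothetical stand-ins, not TVM's actual classes, which enforce the same no-implicit-promotion rule):

```python
class PrimExpr:
    """Toy stand-in for a TIR expression carrying a value and a dtype."""
    def __init__(self, value, dtype):
        self.value = value
        self.dtype = dtype

def mul(a, b):
    # Mirrors the check in tir::Mul: operands must share a dtype,
    # since TIR performs no implicit integer promotion.
    if a.dtype != b.dtype:
        raise TypeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")
    return PrimExpr(a.value * b.value, a.dtype)

dst_offset = PrimExpr(4096, "int64")  # offsets may be int64 for large buffers
index_factor = PrimExpr(2, "int32")   # a bare literal defaults to int32

# Fix: rebuild the factor with the offset's dtype before multiplying.
index_factor = PrimExpr(index_factor.value, dst_offset.dtype)
product = mul(dst_offset, index_factor)
```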

@JackWeiw
Contributor Author

CC @Lunderberg @wrongtest-intellif
Does this change need a unit test? If so, could you give some suggestions on how to write it?

@Lunderberg
Contributor

It would probably be good to add a unit test. The best way to do so would be to define a TIR function that (before the change) would trigger the mismatched dtype bug when passed through tir.transform.InjectPTXAsyncCopy, then assert that the output contains the bugfix. The tvm.testing.CompareBeforeAfter utility is designed to make it easy to write this type of unit test (example link).

In this case, the buggy output contains an extra multiplication step, due to the explicit construction of a tir::Mul node, whereas the fixed output folds the multiplied value into a single integer constant. A test case that captures the current tir.Mul output could be written as follows:

class TestMultiplicationNodesAreInligned(tvm.testing.CompareBeforeAfter):
    transform = tvm.tir.transform.InjectPTXAsyncCopy()

    def before(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")):
        tx = T.launch_thread("threadIdx.x", 32)
        A_flattened = T.Buffer((4096,), "float16", data=A.data)
        A_shared = T.decl_buffer([4096], "float16", scope="shared")

        T.attr("default", "async_scope", 1)
        for i in range(16):
            A_shared[tx * 128 + i * 8 : tx * 128 + i * 8 + 8] = A_flattened[
                tx * 128 + i * 8 : tx * 128 + i * 8 + 8
            ]
        T.ptx_commit_group()
        T.ptx_wait_group(0)

    def expected(A: T.Buffer((32, 128), "float16"), B: T.Buffer((32, 128), "float16")):
        tx = T.launch_thread("threadIdx.x", 32)
        A_shared = T.decl_buffer([4096], "float16", scope="shared")
        for i in range(16):
            T.ptx_cp_async(
                "float16",
                A_shared.data,
                T.Mul(tx * 128 + i * 8, 1),
                A.data,
                tx * 128 + i * 8,
                16,
            )
        T.ptx_commit_group()
        T.ptx_wait_group(0)

With your fix applied, I suspect that this would fail due to the T.Mul, and can then be updated to the new output after your change.

@JackWeiw
Contributor Author

It would probably be good to add a unit test. The best way to do so would be to define a TIR function that (before the change) would trigger the mismatched dtype bug when passed through tir.transform.InjectPTXAsyncCopy […]

Thank you for your suggestion. There is something that confuses me.
I put TestMultiplicationNodesAreInligned into test_tir_transform_inject_ptx_async_copy.py and ran the tests against the buggy version. All 5 tests pass (the original 4 plus the new one).
Maybe that is because your script does not set the dtype of the stride to int64, which is what triggers this bug.
So, what I need to do is modify the before part so that it fails on the buggy version, and update expected to the output produced after the change?

@Lunderberg
Contributor

Ah, I didn't realize that was a necessary step in reproducing the bug. The way to use int64 for the integer literals is to wrap each of them with T.int64(NUM). Unfortunately, there isn't a way to change the dtype for all integer literals across a PrimFunc.
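For illustration, here is a toy model of why each literal has to be wrapped at its use site (tagged tuples standing in for TIR expressions; the `int64`/`int32`/`add` helpers below are hypothetical, not the real TVMScript API):

```python
def int64(n):
    # Mimics T.int64(n): forces this one literal to int64.
    return ("int64", n)

def int32(n):
    # Mimics the default dtype a bare Python literal gets in TVMScript.
    return ("int32", n)

def add(a, b):
    # Like TIR arithmetic, operands must agree on dtype.
    assert a[0] == b[0], f"dtype mismatch: {a[0]} vs {b[0]}"
    return (a[0], a[1] + b[1])

# There is no primfunc-wide switch, so every literal that must be int64
# is wrapped individually.
offset = add(int64(128), int64(8))
```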

@JackWeiw
Contributor Author

Ah, I didn't realize that was a necessary step of reproducing the bug. […]

Sorry for the delayed response. I have changed the buffer access pattern in your script by replacing it with a ramp node.
It would be wrong to simply modify the dtype by wrapping each of the integer literals with T.int64(NUM) (see the attached screenshots of the resulting problem).
The script I committed has the right access pattern, as can be seen in the attached screenshot of the output after the tir.transform.VectorizeLoop() pass.

@Lunderberg
Contributor

Ooh, that's an interesting failure mode, and good debugging on catching it. It looks like the root cause of that extra step is here, where Range(0, extent) uses the default conversion of 0 into PrimExpr, resulting in an int32 dtype. It should use Range(IntImm(extent->dtype, 0), extent) so that the minimum picks up the int64 dtype of the extent. I'm going to make a fix for that, since the tx variable should have the same dtype as is used in the attribute. (This shouldn't impact your test case, as it explicitly produces the same behavior with the T.Cast.)
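The pitfall can be sketched in plain Python (hypothetical stand-ins for TVM's IntImm and Range, not the real classes): passing a bare 0 as the minimum gives it the default int32 dtype rather than the extent's dtype.

```python
class IntImm:
    """Toy model of tvm.tir.IntImm(dtype, value)."""
    def __init__(self, dtype, value):
        self.dtype, self.value = dtype, value

def to_prim_expr(x):
    # Default conversion of a Python int, as in the buggy Range(0, extent):
    # a bare int always becomes int32.
    return x if isinstance(x, IntImm) else IntImm("int32", x)

class Range:
    """Toy model of a TIR Range holding a min and an extent."""
    def __init__(self, min_, extent):
        self.min = to_prim_expr(min_)
        self.extent = to_prim_expr(extent)

extent = IntImm("int64", 32)
buggy = Range(0, extent)                        # min silently ends up int32
fixed = Range(IntImm(extent.dtype, 0), extent)  # min matches the extent's dtype
```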

@JackWeiw
Contributor Author

Ooh, that's an interesting failure mode, and good debugging on catching it. […]

Thanks for the review. Can this PR be merged then?

@Lunderberg
Contributor

Thanks for the review. Can this PR be merged then?

Looks like there were a couple of CI steps that required approval to start. I've started them, and after they finish, the PR can be merged. These are compile-only tests, so they should be done relatively quickly. I'll keep an eye out for when they finish, but feel free to ping me if you notice them finish before I do.

Also, I wanted to say thank you for the extra work in splitting out the separate PRs and adding the unit tests. It can be a bit tedious, but it is very much appreciated in maintaining a testable code base with history suitable for git bisect.

@Lunderberg
Contributor

resulting in an int32 datatype.

On closer inspection, the DataType::Int(32) is also explicitly specified here, even once the Range issue is sorted out, and looks to be based on CUDA's definition of this type as int. Therefore, while the Range(0, T.int64(value)) usage should be sorted out, the dtype of tx should remain T.int32 as in your unit test, with explicit casts to int64 as needed.
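The recommended pattern can be sketched as follows (a toy `cast` helper mirroring the role of T.Cast over tagged tuples; values and names are hypothetical, not the real TVM API):

```python
def cast(dtype, expr):
    # Models an explicit T.Cast: rewrap the value with the requested dtype.
    _, value = expr
    return (dtype, value)

tx = ("int32", 7)               # thread index stays int32, as CUDA defines it
wide_index = cast("int64", tx)  # widen explicitly where an int64 index is needed
offset = ("int64", wide_index[1] * 128)
```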

@JackWeiw
Contributor Author

it is very much appreciated in maintaining a testable code base with history suitable for git bisect

Yeah, I see. It confused me quite a lot before. Thanks for the explanation.

@Lunderberg Lunderberg merged commit 878a611 into apache:main Oct 31, 2023
18 checks passed