-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TE] Fix MakeLoopNest for warp memory #5382
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this condition mean that "warp" is actually treated as "local"? I don't know whether there is a such use case. If so, it may be better to do sanity check, convert from "warp" to "local", and give warning before making loop nests. It is at least worth a test to lock the use case down. If not, giving error message here seems more reasonable. A test is also worthy.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let me explain this piece of code:
Line 167: If we are using a local storage, then each iteration variable
iv
accounts to the accessing offset, but thread indices or block indices do not. For example, when accessing a local buffera
, it should be something likea[i]
, instead ofa[blockIdx.x, threadIdx.x, i]
.Line 168: There are two cases:
b[blockIdx.x, threadIdx.x, i]
.c
, it should bec[threadIdx.x, i]
, instead ofc[blockIdx.x, threadIdx.x, i]
.Now we come to warp storage. Ideally, all the variables inside a warp should account to the offset, but those outside a warp do not. It is complex to determine whether a variable is inside a warp, but usually only
threadIdx.x
is inside, so we made the assumption, and give a warning otherwise.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As for the reason of making the message a warning rather than an error, consider 3 situations:
threadIdx.x
only. NothreadIdx.y
orthreadIdx.z
, no warning is shown.threadIdx.x
andthreadIdx.y
, and the extent ofthreadIdx.x
is 32 (warp size), which is a common use case. Now a warning is shown, but the code is correct.threadIdx.x
andthreadIdx.y
, but the extent ofthreadIdx.x
< 32, now a warning is shown, and the code is wrong.We can still proceed in Situation 2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the elaboration.
I added line 167/168 a few weeks ago. I am glad that this PR is fixing for the "warp" scope, as I wasn't very sure about it and left it with the old behavior to avoid one test failure. Your explanation is almost same as what I thought: when a StorageScope is only visible to a ThreadScope (such as shared to block or local to thread), we don't access it with thread IterVars (such as blockIdx or threadIdx), i.e. that they are not account to the offset.
If my idea makes sense, I am confused by this statement
and the second situation. In my mind, the conceptual "warp" storage is accessible by multiple threads, it makes sense to keep threadIdx IterVar to offset the access. But I don't understand why things change when it comes to outside a warp. In the 2nd situation, is cross "warp" access allowed?
In the 1st situation, does threadIdx.x's extent matter here or is it handled by warp memory lowering?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here is an example for the 2nd situation: Consider extent of
threadIdx.x
= 32, extent ofthreadIdx.y
= 4. Now there are 4 warps, each consisting of 32 threads.No, all of the 128 threads access a warp storage
a
, but threads in different warps are actually accessing different group of registers, although they are all calleda
. Only threads in the same warp are accessing the same group of registers. This is what I mean by inside/outside. Just like shared memory, different blocks access different piece of SRAM, even their variable name is the same.It dose not matter.
lower_warp_memory
pass will handle different exntents.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
threadIdx.x means a lower level of hierarchy than threadIdx.y and threadIdx.z for "warp" scope -- "warp" memory is local to threadIdx.y and threadIdx.z, but shared by threadIdx.x. One idea for you to consider is to change threadIdx.x's rank to 2 and keep threadIdx.y and threadIdx.z's rank 1. It will avoid special casing "warp" here, but may cause other regression.
Does
dom
and/orbind_iv->dom
have the extent for the checks?This example clarifies a lot. Is it possible to make it a test to cover the "warp" specific code here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree. Currently there are some hard-coded rank. It may be difficult to modify the rank definition.
Yes,
dom
does. But where can we get the warp size? Warp size is not a constant value, it is 32 for NVIDIA GPU and 64 for AMD GPU. The warp size information is stored inTarget
class. SinceMakeLoopNest
is designed to be a target-irrelevant function, I don't think it a good idea to add an argument to specify the target.Actually I did try to run this example, but TVM somehow generated a correct code. It only led to a wrong boundary checking bug in a very complex program.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you about
MakeLoopNest
being target independent. My first thought is to defer the check andthreadIdx.x
vs.threadIdx.y/z
handling till LowerWarpMemory(), such as adjusting indices in WarpAccessRewriter(), since it does target specific "warp" lowering. Maybe substitutethreadIdx.y/z
with0
in the group indices for the supported cases and give error otherwise? However, it seems not the case, as your 3rd situation ends with incorrect code instead of an error from LowerWarpMemory(). But I don't know the reason.I see. Does your simplified case trigger the warning? If so, checking for the warning can guard your changes from being accidentally deleted or skipped.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually I mean:
lower_warp_memory
, which will lead to an error. Plus, we will see a warning after this PR.No matter how, I still have no idea how the 2nd situation ends up with incorrect code.
Actually no. Let's have a look at some details.
Here's my simplified example:
And here's the result, which is unexpectedly correct before this PR.
The
if (stage->scope == "warp" && ts.rank == 1)
branch in the modified code is only triggered once, wherets.dim_index == 0
. I don't know why thets.dim_index == 1
IterVar
is ignored.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I modified the simplified example a little to bind threadIdx.y in the warp stage to let threadIdx.y pass through the new code.
The three situations make a good summary.
1st one already has at least one test in tests/python/unittest/test_tir_transform_lower_warp_memory.py.
I hope the above code can lock down the 2nd situation and probably the error for the 3rd one by reducing threadIdx.x's extent.
Warp has been special cased in several places, e.g. in bound.cc and here before this PR. I tried to push back to add more special case code, but I am Ok the accept the current change. Please try to add tests.