Create loops according to storage scope and thread hierarchies #5190
Merged
yongfeng-nv force-pushed the thread-local-bound branch from 9ca46cf to 9600ad4 (April 5, 2020 06:22)
yongfeng-nv force-pushed the thread-local-bound branch from 0e5da1e to b2c434b (April 6, 2020 06:36)
yongfeng-nv changed the title from "WIP: Fix thread local bound (Don't merge)" to "Create loops according to storage scope and thread hierachies." (Apr 6, 2020)
yongfeng-nv changed the title from "Create loops according to storage scope and thread hierachies." to "Create loops according to storage scope and thread hierarchies" (Apr 6, 2020)
@Hzfengsy @vinx13 @ZihengJiang could you help take a look?
…al_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
tqchen approved these changes (Apr 10, 2020)
Thanks @yongfeng-nv! This PR is now merged.
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request (Apr 16, 2020)
…e#5190)
* Set IterVar index to 0 for local thread bound IterVars.
* Lint fix
* Use rank instead of scope name to predicate. Add tests.
* Handle cases other than local/threadIdx.
* Turn warp to the old behavior.
* Modify test to cover global/blockIdx.
* Fix a typo.
* Update test_te_schedule_ops.py with more testing coverage in test_local_stage_predicate; remove test_schedule_schedule_ops.py which was added by mistake.
zhiics pushed a commit to neo-ai/tvm that referenced this pull request (Apr 17, 2020)
dpankratz pushed a commit to dpankratz/incubator-tvm that referenced this pull request (Apr 24, 2020)
The following small example generates wrong IR and code:
The IR:
The direct problem is a wrong predicate for stage B:
A more fundamental cause is that loop nest creation doesn't consider thread/memory hierarchies. In this example, stage B's scope is local. InferBound, needRelax in particular, doesn't relax an IterVar bound to a threadIdx, because the memory space is only used within the thread. However, when creating loops, schedule_ops uses the IterVar to offset the index into the memory regardless of the hierarchies. This is wrong, as the generated code then uses threadIdx to address local memory.
The fix is straightforward: MakeLoopNest adds the thread-bound IterVar to the index only if the storage rank is lower than the thread rank (the same logic as needRelax() in bound.cc). vthread doesn't need this change, as it doesn't reduce the storage scope, i.e. vthreads still share one local memory space.
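To make the rank comparison concrete, here is a minimal pure-Python sketch of the rule described above. It does not use TVM and is not the actual MakeLoopNest code. The rank tables, the helper names `thread_var_offsets_index` and `make_index`, and the choice to treat equal ranks as "add the offset" are all assumptions made for illustration.

```python
# Hypothetical rank tables, assumed for illustration only.
# A larger storage rank means the memory is private to a smaller group of threads.
STORAGE_RANK = {"global": 0, "shared": 1, "local": 3}
THREAD_RANK = {"blockIdx": 0, "threadIdx": 1}


def thread_var_offsets_index(storage_scope, thread_tag):
    """Return True if a thread-bound IterVar should offset the buffer index.

    Assumption: the offset is needed when the storage rank does not exceed
    the thread rank, i.e. when several such threads see the same memory.
    """
    thread_kind = thread_tag.split(".")[0]  # "threadIdx.x" -> "threadIdx"
    return STORAGE_RANK[storage_scope] <= THREAD_RANK[thread_kind]


def make_index(storage_scope, bound_threads, loop_var="i"):
    """Build a symbolic index string from the thread vars that qualify."""
    terms = [t for t in bound_threads if thread_var_offsets_index(storage_scope, t)]
    terms.append(loop_var)
    return " + ".join(terms)


if __name__ == "__main__":
    threads = ["blockIdx.x", "threadIdx.x"]
    for scope in ("global", "shared", "local"):
        print(scope, "->", make_index(scope, threads))
```

Under these assumptions, a local buffer is indexed by the loop variable alone, a shared buffer also picks up threadIdx, and a global buffer picks up both blockIdx and threadIdx.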
A similar issue was identified in this thread: https://discuss.tvm.ai/t/how-to-avoid-allocating-the-wrong-amount-of-registers-in-cuda-scheduling-in-this-example/5310, although it has more to address. I have added that example as a new test.
Questions: