Re-write replaceSymbolicSizes using IdModel #2714
Conversation
In #2671, the `replaceSymbolicSizes` lowering pass calls `ir_utils::replaceValue` with a seemingly benign list of scalar `Val` replacements. However, an error is encountered: when replacing `IterDomain`s whose extents should be replaced, we wind up erasing the definition of a `Split` output. We should instead preserve these definitions and only replace the output of that expression. Fixes #2671
!build --diff
Marking as draft while I investigate some test failures
This fixes some expand-related failures
!build --diff-bench
This also fixes #2766. @jacobhinkle, please add this as a test as well. This is based on the repro of #2766.
Previously, the expand changed the ability to re-use smem. But now that we replace the exact mapped extents, we can _always_ reuse this smem.
This was found when implementing #2714. We currently do not bind TensorMetaData for input tensors to PrecomputedValues. This means we cannot evaluate expressions that contain them, which can lead to errors. This PR binds these metadata structs, which I think is the expected behavior.
Previously, this method only compared the pointers held by two `StructHandle`s. This PR changes it to check that the name, number of fields, and the `DataType` and value of each field match. #2714 (comment)
!build
!build --diff-bench
The only test failures are known H100 Thunder SDPA failures.
@jacobhinkle I merged #2668 into main, updated this branch, and restarted CI, just to ensure there is no conflict.
!build
Good idea. Thanks. |
Do we have basic tests of the replacement? If not, please add some unit tests.
```cpp
tensor_dim_map[orig_extent] = simplified_extent;
}
auto it = tensor_dim_map.find(simplified_extent);
if (it != tensor_dim_map.end()) {
```
Why are we doing this? Previously, there was `tensor_dim_map`, and we updated it with `expr_simplification_map`. Now, it seems we are doing the opposite, which means that if an extent is discovered to be a constant in `expr_simplification_map`, it would be overwritten by the symbolic representation. Am I understanding it correctly?
Yes, we are now doing more replacements than previously. Previously we would only replace extents that are exact mapped to input extents. This was a problem in the motivating example from #2702: see #2702 (comment). In that case it's the `ceilDiv` extent of a reshaped tensor that needs to be mapped as constant, and it's not actually exact mapped with an input.
I understand that, but what happens if an input domain is discovered to be a constant in `extent_simplification_map`? Wouldn't it be overwritten by the `getMetaData` expr even if it's constant?
The purpose of this loop is to handle cases where the simplified extent is an input tensor dimension; it should ignore any `ValGroup`s that simplify to constants or other values.

Suppose we have `i0 = getMetaData[T0].logical_size[0]` in `tensor_dim_map` and we find that `i0` is equal to `5`, so that `extent_simplification_map[i0] = 5`. This loop iterates over `extent_simplification_map` and reaches the entry `i0 -> 5`. We then look up the simplified extent, which is `5` in this case, in `tensor_dim_map`. Since `5` is not a tensor dim, it is not found, so we don't update `extent_simplification_map`.

Now suppose `i1 = getMetaData[T0].logical_size[1]` is in `tensor_dim_map` and we find that another extent is mapped to it in `extent_simplification_map`, e.g. `extent_simplification_map[i3] = i1`. In this case, we will find `i1` in `tensor_dim_map` in this loop, and we will update the entry to `extent_simplification_map[i3] = getMetaData[T0].logical_size[1]`.

I will add examples like these to the comments to make the code clearer.
Ah, I see.
Why is the previous code not enough?
The previous code was composing in the other direction, so that only extents that started out in `tensor_dim_map` were replaced. Specifically, that means only input tensor dimensions were replaced. But there can be other dimensions that are exact mapped with those dimensions, for example if we do a reshape or resize operation before a `BinaryOp`. In the new code, all dimensions get standardized instead.
Hmm, what about line 240 of the previous code? Doesn't it add a mapping that does not exist in `tensor_dim_map`?
Never mind, I misread the code. It's under the `if` branch starting at line 237, not 236.
```cpp
// Check that extents are properly replaced by replaceSymbolicSizes lowering
// pass
TEST_F(NVFuserTest, ReplaceSymbolicSizes) {
```
Added this test, which matches the behavior described in the comments of `replaceSymbolicSizes`.
LGTM. Thanks!
LGTM.
This uses IdModel to implement `replaceSymbolicSizes`. Extents are replaced with a single representative from their exact-graph `ValGroup`, selected according to a precedence order that falls back to `name()`. Fixes #2702. Fixes #2766