
Re-write replaceSymbolicSizes using IdModel #2714

Merged 30 commits into main on Aug 26, 2024

Conversation

@jacobhinkle (Collaborator) commented Jul 30, 2024

This PR re-implements replaceSymbolicSizes using IdModel. Each extent is replaced with a single representative from its exact-graph ValGroup, chosen with the following precedence:

  1. Constants are preferred.
  2. If no constants exist, prefer the extents of fusion inputs.
  3. Ties are broken by choosing the scalar with the smallest name().

Fixes #2702. Fixes #2766.

In #2671, the replaceSymbolicSizes lowering pass calls
ir_utils::replaceValue with a seemingly benign list of scalar Val
replacements. However, an error occurs: when replacing IterDomains
whose extents should change, we wind up erasing the definition of a
Split output. We should instead preserve these definitions and replace
only the output of that expression.

Fixes #2671
@jacobhinkle jacobhinkle changed the title Id model replace sizes Re-write replaceSymbolicSizes using IdModel Jul 30, 2024
@jacobhinkle jacobhinkle changed the base branch from main to mutator_sibling_ids July 30, 2024 00:48
@jacobhinkle (Collaborator, Author)

!build --diff

Base automatically changed from mutator_sibling_ids to main August 1, 2024 19:14
@jacobhinkle jacobhinkle marked this pull request as ready for review August 1, 2024 19:14
@jacobhinkle jacobhinkle marked this pull request as draft August 1, 2024 19:22
@jacobhinkle (Collaborator, Author)

Marking as draft while I investigate some test failures.

@jacobhinkle (Collaborator, Author)

!build --diff-bench

@naoyam (Collaborator) commented Aug 7, 2024

This also fixes #2766.

@jacobhinkle, please add this as a test as well. This is based on the repro of #2766.

// Repro of #2766 
TEST_F(NVFuserTest, SmallOuterBlockReductionRepro) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto& fusion = *fusion_ptr;
  FusionGuard fg(&fusion);

  std::vector<int64_t> shape{100, 2, 128};

  auto tv0 = makeContigTensor(2);
  fusion.addInput(tv0);

  auto tv1 = reshape(
      tv0,
      {IrBuilder::create<Val>(shape[0]),
       IrBuilder::create<Val>(shape[1]),
       IrBuilder::create<Val>(shape[2])});
  auto tv2 = sum(tv1, {1});
  fusion.addOutput(tv2);

  // Previously, after the extent replacement in lowering, the reduction
  // reference tensor got a reduction domain of a static size (just 1), but
  // the pre-reshape tensors still used symbolic extents. Before #2771, the
  // scheduler decided not to use TIDy because the reference tensor has a
  // static size of 1, but since the other tensors still had dynamic sizes,
  // this resulted in the dynamic allocation error.

  auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
  auto t0 = at::randn({shape[0] * shape[1], shape[2]}, options);
  std::vector<c10::IValue> inputs({t0});

  FusionExecutorCache fec(std::move(fusion_ptr));
  auto outputs = fec.runFusionWithInputs(inputs);

  testValidate(fec.fusion(), outputs, inputs, __LINE__, __FILE__);
}

Previously, the expand changed the ability to re-use smem. But now that
we replace the exact-mapped extents, we can _always_ reuse this smem.
jacobhinkle added a commit that referenced this pull request Aug 20, 2024
This was found when implementing #2714. We currently do not bind
TensorMetaData for input tensors to PrecomputedValues. This means we
cannot evaluate expressions that contain them, which can lead to errors.
This PR binds these metadata structs, which I think is the expected
behavior.
jacobhinkle added a commit that referenced this pull request Aug 20, 2024
Previously, this method only compared the pointers held by two
StructHandles. This PR changes it to check that the name, number of
fields, and the DataType and value of each field match.

#2714 (comment)
@jacobhinkle jacobhinkle marked this pull request as ready for review August 21, 2024 00:34
@jacobhinkle (Collaborator, Author)

!build

jacobhinkle added a commit that referenced this pull request Aug 22, 2024
@jacobhinkle (Collaborator, Author)

!build --diff-bench

@jacobhinkle (Collaborator, Author)

The only test failures are known H100 thunder SDPA failures.

@liqiangxl (Collaborator) commented Aug 23, 2024

@jacobhinkle I merged #2668 into main, updated this branch, and restarted CI, just to make sure there is no conflict.

@liqiangxl (Collaborator)

!build

@jacobhinkle (Collaborator, Author)

> @jacobhinkle I merged #2668 into main, updated this branch, and restarted CI, just to make sure there is no conflict.

Good idea. Thanks.

@naoyam (Collaborator) left a comment

Do we have basic tests of the replacement? If not, please add some unit tests.

csrc/device_lower/pass/replace_size.cpp (resolved review thread)
tensor_dim_map[orig_extent] = simplified_extent;
}
auto it = tensor_dim_map.find(simplified_extent);
if (it != tensor_dim_map.end()) {
naoyam (Collaborator):

Why are we doing this? Previously, there's tensor_dim_map, and we update it with expr_simplification_map. Now, it seems we are doing the opposite, which means if an extent is discovered to be a constant in expr_simplification_map, it would be overwritten by the symbolic representation. Am I understanding it correctly?

jacobhinkle (Collaborator, Author):

Yes, we now do more replacements than previously. Previously we would only replace extents that are exact-mapped to input extents. This was a problem in the motivating example from #2702: see #2702 (comment). In that case it is the "ceilDiv" extent of a reshaped tensor that needs to be mapped to a constant, and it is not actually exact-mapped with an input extent.

naoyam (Collaborator):

I understand that, but what happens if an input domain is discovered to be a constant in extent_simplification_map? Wouldn't it be overwritten by the getMetaData expr even though it's constant?

jacobhinkle (Collaborator, Author):

The purpose of this loop is to handle cases where the simplified extent is an input tensor dimension; it should ignore any ValGroups that simplify to constants or other values.

Suppose we have i0 = getMetaData[T0].logical_size[0] in tensor_dim_map and we find that i0 is equal to 5, so that extent_simplification_map[i0] = 5. This loop iterates over extent_simplification_map and reaches the entry i0 -> 5. We then look up the simplified extent (5 in this case) in tensor_dim_map. Since 5 is not a tensor dim, it is not found, so we don't update extent_simplification_map.

Now suppose i1 = getMetaData[T0].logical_size[1] in tensor_dim_map and we find that another extent is mapped to it in extent_simplification_map, e.g. extent_simplification_map[i3] = i1. In this case, we will find i1 in tensor_dim_map in this loop, and we will update it to extent_simplification_map[i3] = getMetaData[T0].logical_size[1].

I will add examples like these to the comments to make the code more clear.

naoyam (Collaborator):

Ah, I see.

Why is the previous code not enough?

jacobhinkle (Collaborator, Author):

The previous code composed in the other direction, so only extents that started out in tensor_dim_map were replaced; that is, only input tensor dimensions. But there can be other dimensions that are exact-mapped with those, for example when a reshape or resize occurs before a BinaryOp. In the new code, all dimensions are standardized instead.

naoyam (Collaborator):

Hmm, what about line 240 of the previous code? Doesn't it add a mapping that does not exist in tensor_dim_map?

naoyam (Collaborator):

Never mind, I misread the code. It's under the if branch starting at line 237, not 236.

Comment on lines +8774 to +8776
// Check that extents are properly replaced by replaceSymbolicSizes lowering
// pass
TEST_F(NVFuserTest, ReplaceSymbolicSizes) {
jacobhinkle (Collaborator, Author):

Added this test, which matches the behavior described in the comments of replaceSymbolicSizes.

@naoyam (Collaborator) left a comment

LGTM. Thanks!

@liqiangxl (Collaborator) left a comment

LGTM.

@jacobhinkle jacobhinkle merged commit 3b61042 into main Aug 26, 2024
5 checks passed
@jacobhinkle jacobhinkle deleted the id_model_replace_sizes branch August 26, 2024 18:50