Extra S->S copy introduced by tmaLoad->broadcast->mma. #1628

Closed
zasdfgbnm opened this issue Jan 15, 2024 · 14 comments · Fixed by #1672, #1736, #1776 or #1793

Comments

@zasdfgbnm
Collaborator

zasdfgbnm commented Jan 15, 2024

When we define matmul, we have inputs of shape [M, K] and [K, N], and we broadcast them to [M, K, 1] and [1, K, N], then we do a mul and sum to obtain the matmul result. From [M, K] to [M, K, 1], there must be a BroadcastOp. Unfortunately, this does not work with Hopper, because for a fast Hopper matmul kernel, we must use TMA for global -> shared memory loading and use MMA instructions to pull data directly from shared memory. If we are unable to fuse the global -> shared copy with the broadcasting op, we will end up with a kernel like:

T1_s = tmaLoad(T0_g);
T2_s = broadcast(T1_s);
... = mma(T2_s, ...);

That is, we need to double the smem usage to store the same data twice (T1 and T2), one as the TMA consumer and another as the MMA producer, and do an S->S copy between these two buffers.
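To put a number on the doubling, here is a back-of-the-envelope sketch. The 128 x 64 tile of 2-byte (half-precision) elements is an illustrative assumption, not a value from this issue, and `tileBytes` / `operandSmemBytes` are hypothetical helpers:

```cpp
#include <cassert>
#include <cstddef>

// Bytes of shared memory needed to hold one operand tile of
// rows x cols elements, each elem_bytes wide.
constexpr std::size_t tileBytes(std::size_t rows, std::size_t cols,
                                std::size_t elem_bytes) {
  return rows * cols * elem_bytes;
}

// Shared memory per operand for the pattern
//   T1_s = tmaLoad(T0_g); T2_s = broadcast(T1_s); ... = mma(T2_s, ...);
// When the broadcast is a separate op, T1_s (TMA destination) and T2_s
// (MMA source) are distinct buffers holding the same data.
constexpr std::size_t operandSmemBytes(std::size_t rows, std::size_t cols,
                                       std::size_t elem_bytes,
                                       bool separate_broadcast_buffer) {
  return separate_broadcast_buffer ? 2 * tileBytes(rows, cols, elem_bytes)
                                   : tileBytes(rows, cols, elem_bytes);
}
```

With the assumed tile, `operandSmemBytes(128, 64, 2, false)` is 16384 bytes, while the broadcast-in-between version needs 32768 bytes for the same operand, before counting the extra S->S copy itself.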

I am still learning but according to my current understanding, the smem data written by TMA is supposed to be used by MMA directly, so this extra S->S copy introduced by the BroadcastOp does not make sense.
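For reference, the mul-and-sum definition of matmul from the first paragraph can be written as a naive CPU loop nest. This is only a sketch of the semantics (row-major `std::vector` storage and the name `matmulViaBroadcast` are assumptions), not how nvFuser generates code:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// C[M, N] = sum over K of A[M, K, 1] * B[1, K, N].
// a is row-major [M, K], b is row-major [K, N]. Broadcasting A to
// [M, K, 1] and B to [1, K, N] puts the elementwise product in an
// [M, K, N] space; summing over the K axis yields the [M, N] result.
std::vector<float> matmulViaBroadcast(const std::vector<float>& a,
                                      const std::vector<float>& b,
                                      std::size_t m, std::size_t k,
                                      std::size_t n) {
  assert(a.size() == m * k && b.size() == k * n);
  std::vector<float> c(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      float acc = 0.0f;
      for (std::size_t kk = 0; kk < k; ++kk) {
        // a_bcast[i, kk, j] == a[i, kk] and b_bcast[i, kk, j] == b[kk, j]
        acc += a[i * k + kk] * b[kk * n + j];
      }
      c[i * n + j] = acc;
    }
  }
  return c;
}
```

For example, `matmulViaBroadcast({1, 2, 3, 4}, {5, 6, 7, 8}, 2, 2, 2)` gives `{19, 22, 43, 50}`. The broadcast dimensions never hold distinct data, which is why materializing them in a second smem buffer buys nothing.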

What I am proposing is to go through the same process I went through in the past for TransposeOp: in the past, transposes had to be done in a TransposeOp, but after my change, any op can have a fused transpose by specifying the order in its rFactor domain. Most commonly, the op with a fused transpose is a LoadStoreOp, but it is not required to be.

I want broadcasting (and, to be consistent, squeeze as well) to follow the same contract: there will no longer be a BroadcastOp or SqueezeOp; instead, broadcasting IterDomains are allowed to be created or annihilated at any TensorView. (The terms "created" and "annihilated" are borrowed from particle physics, because a broadcasting IterDomain feels like a boson: it can be created from nowhere, and it can disappear whenever it wants.) For example, if we want to do a broadcast, we can define:

T1 = LoadStoreOp(T0)

T0: root domain [I1, I2]
T1: root domain [I3, I4]
    rFactor domain [I3, b5, I4]

We can also do a fused squeeze-broadcast:

T1 = LoadStoreOp(T0)

T0: root domain [I1, b3, I2]
T1: root domain [I4, b5, I6]
    rFactor domain [I4, b7, I6]
// squeeze b5, broadcast b7

The main use case I want to support is LoadStoreOp with fused broadcasting or squeeze, but I do not see any reason why LoadStoreOp is special, so my expectation is, when this work is done, an arbitrary tensor op can have fused broadcasting or squeeze.
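Under this contract, which broadcasts an op introduces or removes is derivable purely from its output's root and rFactor domains. A toy sketch of that derivation, assuming IterDomains are plain names and a leading `b` marks a broadcast (this is a hypothetical model, not nvFuser's API):

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy model: an IterDomain is a name; names starting with 'b' are broadcasts.
using Domain = std::vector<std::string>;

bool isBroadcast(const std::string& id) { return !id.empty() && id[0] == 'b'; }

bool contains(const Domain& d, const std::string& id) {
  return std::find(d.begin(), d.end(), id) != d.end();
}

// Broadcast IDs present in the rFactor domain but not in the root domain
// are "created" by the op (a fused broadcast).
Domain createdBroadcasts(const Domain& root, const Domain& rfactor) {
  Domain out;
  for (const auto& id : rfactor) {
    if (isBroadcast(id) && !contains(root, id)) out.push_back(id);
  }
  return out;
}

// Broadcast IDs present in the root domain but not in the rFactor domain
// are "annihilated" by the op (a fused squeeze).
Domain annihilatedBroadcasts(const Domain& root, const Domain& rfactor) {
  Domain out;
  for (const auto& id : root) {
    if (isBroadcast(id) && !contains(rfactor, id)) out.push_back(id);
  }
  return out;
}
```

For the squeeze-broadcast example, root `[I4, b5, I6]` and rFactor `[I4, b7, I6]` yield created `{b7}` and annihilated `{b5}`; for the plain broadcast example, root `[I3, I4]` and rFactor `[I3, b5, I4]` yield created `{b5}` and nothing annihilated.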

Calling for review: @naoyam @jacobhinkle @wujingyue @protonu @drzejan2

@naoyam
Collaborator

naoyam commented Jan 15, 2024

QQ:

> T1_s = tmaLoad(T0_g);
> T2_s = broadcast(T1_s);
> ... = mma(T2_s, ...);
>
> That is, we need to double the smem usage to store the same data twice (T1 and T2), one for TMA consumer, another for MMA producer, and do a S->S copy between these two memory.

Doesn't the alias analysis allow reusing the same buffer? (CC: @jacobhinkle )

@jacobhinkle
Collaborator

> QQ:
>
> > T1_s = tmaLoad(T0_g);
> > T2_s = broadcast(T1_s);
> > ... = mma(T2_s, ...);
> >
> > That is, we need to double the smem usage to store the same data twice (T1 and T2), one for TMA consumer, another for MMA producer, and do a S->S copy between these two memory.
>
> Doesn't the alias analysis allow reusing the same buffer? (CC: @jacobhinkle )

In theory it could. Currently though, broadcasts prevent inner aliasing like this since they can alter indexing in ways that make it difficult to guarantee the alias is safe. We could try updating that logic though.

@zasdfgbnm
Collaborator Author

> In theory it could. Currently though, broadcasts prevent inner aliasing like this since they can alter indexing in ways that make it difficult to guarantee the alias is safe. We could try updating that logic though.

Besides that, in this program the last read of T1_s is line 2 (T2_s = broadcast(T1_s);), and the first write of T2_s is also this line. IIRC, [first write, last read] is considered a closed interval, so the lifetimes of T1_s and T2_s overlap and the buffer cannot be reused.
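The closed-interval argument can be made concrete with a small sketch. `LiveRange` and `overlaps` are hypothetical names, and this is a simplified model of buffer reuse, not nvFuser's actual alias pass:

```cpp
#include <cassert>

// Live range of a buffer as [first write, last read] in program order.
// Both ends are inclusive (a closed interval).
struct LiveRange {
  int first_write;
  int last_read;
};

// Two buffers may share storage only if their live ranges are disjoint.
// With closed intervals, touching at a single line already counts as overlap.
bool overlaps(const LiveRange& a, const LiveRange& b) {
  return a.first_write <= b.last_read && b.first_write <= a.last_read;
}
```

For the three-line program, T1_s lives over [1, 2] and T2_s over [2, 3], so `overlaps({1, 2}, {2, 3})` is true and reuse is rejected even though line 2 is the only shared point.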

@zasdfgbnm
Collaborator Author

> Doesn't the alias analysis allow reusing the same buffer? (CC: @jacobhinkle )

Even if it does, I would still prefer not to rely on alias analysis. The first reason is that, compared with not creating two buffers in the first place, this can be fragile. The second reason is that both T1_s and T2_s need to be swizzled, and propagating the swizzle can be tricky: (1) currently, the transform propagator does not propagate swizzles; (2) we want to propagate the swizzle from T2_s to T1_s, but not further to T0_g, which is doable but inconvenient.
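A toy sketch of the "propagate to T1_s but not to T0_g" rule: walk up the producer chain and stop at the first tensor that is not in shared memory. The `Tensor` struct and `propagateSwizzleToSmemProducers` are hypothetical illustrations, not nvFuser's TransformPropagator:

```cpp
#include <cassert>
#include <string>
#include <vector>

enum class MemoryType { Global, Shared, Local };

// Toy tensor: memory type plus the index of its single producer (-1 for inputs).
struct Tensor {
  std::string name;
  MemoryType mem;
  int producer;
  bool swizzled = false;
};

// Starting from `start`, apply the swizzle to every shared-memory tensor on
// the producer chain, stopping at the first non-shared tensor (for example,
// the global-memory source of the TMA load).
void propagateSwizzleToSmemProducers(std::vector<Tensor>& tensors, int start) {
  for (int i = start; i >= 0 && tensors[i].mem == MemoryType::Shared;
       i = tensors[i].producer) {
    tensors[i].swizzled = true;
  }
}
```

With the chain T0_g -> T1_s -> T2_s, starting from T2_s marks T2_s and T1_s as swizzled and leaves T0_g untouched.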

@wujingyue
Collaborator

To avoid https://xyproblem.info/, I think the real problem here is that no IR today represents a ldmatrix with an implicit broadcast.

One solution to that is to overload LoadStoreOp to support implicit broadcasting, which is part of @zasdfgbnm 's suggestion. In fact, LoadStoreOp has already been overloaded to represent semantics at different levels, from identity in math, to transferring data between address spaces, to transferring data between different layouts. #1343 is a similar example showing that LoadStoreOp was made to support a non-trivial rfactor domain so it can represent ldmatrix.trans.

Another common solution is to introduce a separate op for ldmatrix. We haven't seemed to discuss that yet -- what do you all think?

That being said, killing BroadcastOp is an orthogonal, heavy lift for arguable wins, so from a practical perspective I wouldn't do it. There's merit to keeping high-level op semantics simple, separate and composable. It makes analysis and optimization easier. For example, AliasFinder today is able to separately handle LoadStoreOp, BroadcastOp, and SqueezeOp. (IMHO, it would be even cleaner if we hadn't removed TransposeOp.) To help clarify the contract between BroadcastOp and LoadStoreOp, we could restrict BroadcastOp to only be used in certain stages, e.g., before device lowering.

@zasdfgbnm
Collaborator Author

zasdfgbnm commented Jan 16, 2024

> I think the real problem here is that no IR today represents a ldmatrix with an implicit broadcast.

Almost correct, except that it is not an ldmatrix but a cp.async.bulk.tensor.

> Another common solution is to introduce a separate op for ldmatrix. We haven't seemed to discuss that yet -- what do you all think?

As I just mentioned, what I wanted was for cp.async.bulk.tensor, not for ldmatrix.trans. So let me split this question into two questions:

  1. What if we have not killed TransposeOp, but instead, we created a separate LdMatrixOp, what benefit will we get and what trouble will we have?
  2. How about cp.async.bulk.tensor? Does the reasoning for ldmatrix apply to cp.async.bulk.tensor?

For 1, AFAICT, if we had not killed TransposeOp, the benefit we would get is:

  • LoadStoreOp would be dumb and simple: just a trivial op that does not change the semantics of anything.
    • This would make some analyses simpler, because they would just be an if (expr->isA<LoadStoreOp>()) { return false; }, instead of having to go through the IterDomains and check whether there are transposes.
    • But be warned that we still cannot assume LoadStoreOp is a no-op, because the allocation domains of the producer and consumer might be different. LoadStoreOp means "no semantic permutation", but not "no physical permutation".
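The "go through the IterDomains" check can be sketched with a toy model in which a LoadStoreOp output has a root domain and an optional rFactor domain, both lists of names (`ToyLoadStoreOutput` and `hasFusedTranspose` are hypothetical). Note that, per the warning above, the toy ignores allocation domains, so a `false` answer still does not make the op a physical no-op:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Toy LoadStoreOp output: root and rFactor domains as IterDomain names.
struct ToyLoadStoreOutput {
  std::vector<std::string> root;
  std::vector<std::string> rfactor;  // empty means "same as root"
};

// With a dedicated TransposeOp this would be a type check. With fused
// transposes, we compare orders: the rFactor domain is a transpose of the
// root domain if it is a permutation of it that is not the identity.
bool hasFusedTranspose(const ToyLoadStoreOutput& tv) {
  const auto& rf = tv.rfactor.empty() ? tv.root : tv.rfactor;
  return rf.size() == tv.root.size() &&
         std::is_permutation(rf.begin(), rf.end(), tv.root.begin()) &&
         rf != tv.root;
}
```

A plain copy (`{I0, I1}` with no rFactor domain) and a fused broadcast (`{I0, I1}` -> `{I0, b2, I1}`) both report `false`, while `{I0, I1}` -> `{I1, I0}` reports `true`.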

But what we would lose is:

  • We would be unable to vectorize transposes. For example, if I have T1_s = transpose(T0_l), I would not be able to vectorize the write to smem.
  • We would need special handling for TransposeOp and LdMatrixOp in PairwiseRootDomainMap.
  • We would be unable to use cacheBefore or cacheAfter for smem->register loading in the matmul scheduler.
  • Without allowing a fused transpose in MmaOp, we could not support the NN format of matmul, because the output of matmul would be [N, M] instead of [M, N].

All the above points are solvable, but the solutions mostly amount to replicating whatever already exists for LoadStoreOp and applying it to TransposeOp and LdMatrixOp. So in general, at least at the time I made that change, killing TransposeOp was a net reduction in the overall complexity of the system. As we can see in #148, 200+ lines of code were removed.

For 2, I am still typing, will post in the next reply.

@zasdfgbnm
Collaborator Author

(continue my reasoning for 2)

Regarding cp.async.bulk.tensor: unlike ldmatrix, which is a very specific op that does a very specific thing and is only useful for matmul kernels, cp.async.bulk.tensor is much more general and will be widely used in all kinds of kernels (however, the need for fusing a broadcast is limited to matmul). Most of the time, cp.async.bulk.tensor will be used in its non-broadcasting form. From this perspective, cp.async.bulk.tensor is mostly just a LoadStoreOp.

Another side benefit of making LoadStoreOp capable of having a fused broadcast is, I can make ldmatrix have a fused broadcast as well. This way, I can avoid a line of code like T2[i1] = T1[i1] generated from broadcast. This improves code readability, because when a human writes code, they will not write something like that.

===================

From a higher-level perspective, another reason I want to make a broadcasting IterDomain capable of being created or annihilated arbitrarily is that I don't care when a broadcasting IterDomain starts to exist. What is interesting to me is when it is resolved (i.e., when a broadcasting ID is combined with a concrete ID and becomes concrete). So in general, I believe we should focus our attention on broadcast concretization, not on the appearance of a new broadcasting IterDomain. This move de-emphasizes the appearance of a broadcasting IterDomain.
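Concretization, the event this argument cares about, can be sketched with a toy IterDomain that carries only an extent and a broadcast flag (`ToyId`, `combine`, and `concretizesBroadcast` are hypothetical names, not nvFuser's API):

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Toy IterDomain: an extent plus a broadcast flag.
struct ToyId {
  int64_t extent;
  bool is_broadcast;
};

// When a pointwise op lines up two IterDomains, a broadcast paired with a
// concrete ID is concretized and takes the concrete extent. Two broadcasts
// stay broadcast; two concrete IDs must agree.
ToyId combine(const ToyId& a, const ToyId& b) {
  if (a.is_broadcast && b.is_broadcast) return {1, true};
  if (a.is_broadcast) return b;
  if (b.is_broadcast) return a;
  if (a.extent != b.extent) throw std::invalid_argument("extent mismatch");
  return a;
}

// True exactly when this pairing resolves a broadcast.
bool concretizesBroadcast(const ToyId& a, const ToyId& b) {
  return a.is_broadcast != b.is_broadcast;
}
```

In this model, where the broadcast ID came from does not affect the result; only the point where it meets a concrete ID (for example, `combine({1, true}, {512, false})` yielding a concrete extent of 512) carries information.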

@jacobhinkle
Collaborator

I favor removing these ops since they are redundant: the TensorDomains hold all the info to infer a squeeze or broadcast, just like they do for transpose.

My understanding is that the broadcast op would become like set, except with an rfactor domain introduced. In that case, in order to simplify cases like the double-smem one above, would we need a pass to remove those trivial LoadStoreOps and instead modify the upstream tensor's rfactor domain? Explicitly, this is the case you listed @zasdfgbnm

T1_s = tmaLoad(T0_g);
T2_s = broadcast(T1_s);
... = mma(T2_s, ...);

Currently the definition of T2_s is a BroadcastOp but instead it would become a LoadStoreOp that we could try to remove in another lowering pass, IIUC. Or is it possible for us to provide some more generic way to modify the output of general ops to let us squeeze or insert broadcasts, which could be applied to not just tmaLoad...
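A toy version of the folding pass described above, assuming a straight-line list of single-input ops (`ToyOp` and `foldBroadcasts` are hypothetical, and the real IR is a DAG, so this only illustrates the idea). A `broadcast` whose input is produced by the immediately preceding op, and which is that input's only use, is folded into the producer, so the producer writes the broadcast output directly:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy straight-line IR: each op reads one tensor and writes one tensor.
struct ToyOp {
  std::string kind;  // e.g. "tmaLoad", "broadcast", "mma"
  std::string in, out;
  std::vector<std::string> out_domain;
};

std::vector<ToyOp> foldBroadcasts(std::vector<ToyOp> ops) {
  for (std::size_t i = 1; i < ops.size(); ++i) {
    if (ops[i].kind != "broadcast" || ops[i - 1].out != ops[i].in) continue;
    // Only fold if the broadcast is the sole consumer of its input.
    int uses = 0;
    for (const auto& op : ops) uses += (op.in == ops[i].in) ? 1 : 0;
    if (uses != 1) continue;
    // The producer now writes the broadcast output with its domain,
    // which removes the broadcast op and its S->S copy.
    ops[i - 1].out = ops[i].out;
    ops[i - 1].out_domain = ops[i].out_domain;
    ops.erase(ops.begin() + static_cast<std::ptrdiff_t>(i));
    --i;  // re-examine the op that moved into position i
  }
  return ops;
}
```

Applied to `tmaLoad(T0_g) -> T1_s`, `broadcast(T1_s) -> T2_s`, `mma(T2_s)`, the result is `tmaLoad(T0_g) -> T2_s` followed by `mma(T2_s)`.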

@naoyam
Collaborator

naoyam commented Jan 18, 2024

Asked Xiang to schedule a meeting, but first of all I'd emphasize Jingyue's comment:

> There's merit to keeping high-level op semantics simple, separate and composable. It makes analysis and optimization easier.

Technically, it is possible to mix many IterDomain ops with TensorView ops. For example, we could mix Resize with normal pointwise ops. However, in principle, I believe we should keep each op as simple as possible without overloading it with too many capabilities.

So, in principle, I'd keep BroadcastOp as is and oppose letting LoadStoreOp take it over. However, I'd need to understand whether that's really the only way to work around the problems.

One immediate concern on removing BroadcastOp is, if we do so, we would lose track of mappings between broadcast domains over producer-consumer tensors. Per Xiang:

> So in general, I believe we should focus our attention on broadcasting concretization, not the appearance of a new broadcasting IterDomain. Making this move de-emphasize the appearance of a broadcasting IterDomain.

It is true that concretization is important information, but we also track which domains are eventually concretized. I'm not sure this information is unnecessary. I need to think about it more.

@wujingyue
Collaborator

> All the above points are solvable, but the solution is mostly like replicating whatever already exist for LoadStoreOp and applying it to TransposeOp and LdMatrixOp

Got it -- that's a fair and practical point. Given cp.async.bulk and ldmatrix are already implemented as LoadStoreOp's subtypes, extending it to support implicit broadcasting seems to have a higher ROI than creating a separate op.

@zasdfgbnm
Collaborator Author

Just an update about our recent offline discussion:

We decided to make MatmulScheduler::canScheduleCompileTime reject fusions that contain a BroadcastOp and take only the mma, so that the fusion inputs become 3D (for example [M, 1, K] and [1, K, N]). This way, the problem described here does not arise. This approach also helps with #1707.


zasdfgbnm added a commit that referenced this issue Feb 8, 2024
This PR renames `matmulAtInput` into `matmulAtInput2D`, explicitly
showing that it generates 2D inputs. This PR also adds a
`matmulAtInput3DTuring`, which is used to generate the 3D fusion inputs
(for example `[M, 1, K]` and `[1, K, N]`) for matmul. The `MmaTest` for
Turing and Ampere is modified to exclude the `BroadcastOp` and use the
3D version for generating fusion inputs. This is only the initial step
for making `scheduleMatmul` schedule a fusion not containing
`BroadcastOp`, I intentionally keep it small. Other changes will be
added in followup PRs.

Fixes #1628
@wujingyue wujingyue changed the title Kill BroadcastOp and make broadcasting an ID op instead of a tensor op Extra S->S copy introduced by tmaLoad->broadcast->mma. Feb 8, 2024
@wujingyue
Collaborator

Great! Just to be more precise, I changed the title to reflect the problem rather than an outdated proposal.

@zasdfgbnm zasdfgbnm reopened this Feb 8, 2024
@zasdfgbnm
Collaborator Author

> Great! Just to be more precise, I changed the title to reflect the problem not an outdated proposal.

Thank you for updating the title!

===

This issue is not resolved yet, so I am reopening it.

cowanmeg added a commit to samnordmann/Fuser that referenced this issue Feb 13, 2024
* print bandwidth when perf_debug_verbose is true (NVIDIA#1689)

print bandwidth when `perf_debug_verbose` is true.

* in vectorization validation, add err msg if tv has no definition (NVIDIA#1690)

check the existence of tv definition in vectorization validation

* Accommodate Reduction IterDomains when concretizing reshape extents (NVIDIA#1692)

We register extents for concretization when we concretize reshape. In
order to do that, we line up `IterDomain`s in the symbolic reshaped TV
and the new, concretized one. In cases where the concretized reshape is
trivial, such as when the output shape is the same as the input, we do
not create a new TV. In those cases, we will have the input to the
original `ViewOp` as the concretized output. That input TV might have
reduction domains, as in the provided test, in which case we need to
filter those out when doing this alignment. This small PR just
implements that filtering.

Fixes NVIDIA#1691.
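The filtering step can be sketched with a toy model: drop reduction IterDomains from the concretized side before pairing positions. `ToyIterDomain`, `noReductions`, and `alignForConcretization` are hypothetical names for illustration, not the functions touched by this PR:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy IterDomain: a name and a reduction flag.
struct ToyIterDomain {
  std::string name;
  bool is_reduction;
};

// Drop reduction IterDomains, keeping the relative order of the rest.
std::vector<ToyIterDomain> noReductions(const std::vector<ToyIterDomain>& ids) {
  std::vector<ToyIterDomain> out;
  for (const auto& id : ids) {
    if (!id.is_reduction) out.push_back(id);
  }
  return out;
}

// Pair the symbolic reshape output with the concretized TV position by
// position. The concretized TV may be the original ViewOp input (trivial
// reshape), which can carry reduction domains, so filter those first.
std::vector<std::pair<std::string, std::string>> alignForConcretization(
    const std::vector<ToyIterDomain>& symbolic,
    const std::vector<ToyIterDomain>& concretized) {
  auto filtered = noReductions(concretized);
  assert(filtered.size() == symbolic.size());
  std::vector<std::pair<std::string, std::string>> pairs;
  for (std::size_t i = 0; i < symbolic.size(); ++i) {
    pairs.emplace_back(symbolic[i].name, filtered[i].name);
  }
  return pairs;
}
```

Without the filter, a reduction domain in the middle of the concretized TV would shift every later pairing by one position.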

* `MmaOp::evaluate` method (NVIDIA#1675)

* Fix some typos. (NVIDIA#1700)

* `torch.compile` and `eager` benchmarks for `softmax` (NVIDIA#1670)

Adds `torch.compile` and `eager` baseline benchmarks to be used in
weekly benchmark runs.
Issue NVIDIA#1668.

* Add a test for fusions with no inputs. (NVIDIA#1709)

As a follow up to
NVIDIA#1696 (comment).

* Double the size of the fusion cache to work around a CI issue. (NVIDIA#1702)

By just removing entries when it fills up.

* Check that the reduced axis is sharded on producer in isLowerableToCommunication (NVIDIA#1695)

Currently, a reduction is lowerable to a communication iff only one axis
is reduced and this axis is sharded across devices on the **producer**
side.
Before this patch, we would mistakenly check that the axis is sharded on
**consumer** side, which led to some runtime assert error.
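The corrected rule can be sketched as a predicate over a toy per-axis description (`AxisInfo` and `reductionIsLowerableToCommunication` are hypothetical, not the actual `isLowerableToCommunication` signature):

```cpp
#include <cassert>
#include <vector>

// Toy per-axis description of a reduction expr.
struct AxisInfo {
  bool reduced;              // axis is reduced by the expr
  bool sharded_on_producer;  // axis is device-parallel in the producer
  bool sharded_on_consumer;  // axis is device-parallel in the consumer
};

// Lowerable iff exactly one axis is reduced and that axis is sharded on the
// producer side. The bug fixed here consulted consumer-side sharding instead.
bool reductionIsLowerableToCommunication(const std::vector<AxisInfo>& axes) {
  int num_reduced = 0;
  bool reduced_axis_sharded = false;
  for (const auto& a : axes) {
    if (!a.reduced) continue;
    ++num_reduced;
    reduced_axis_sharded = a.sharded_on_producer;
  }
  return num_reduced == 1 && reduced_axis_sharded;
}
```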

* Add blank impl of isLowerableToCommunication. (NVIDIA#1698)

isLowerableToCommunication is used in a few places to print error
messages or short-circuit loops. Those places appear to be places that
are intended to largely be used behind the distributed path. It's easier
to just define the API instead of trying to conditionalize all the use
sites and invent non-USE_DISTRIBUTED behavior.

* Multidevice segmenter (NVIDIA#1696)

# What
Add an option in the segmenter to place each resharding Expr in a
separate singleton segment.
To trigger it, set the segmenter's options as follows:
```
    SegmentCandidateFinderOptions options{
        .run_translate_welford = false,
        .run_combine_reductions = false,
        .run_herrmann_merge = true,
        .run_final_merge = true,
        .only_segment_resharding_exprs = true};
```
and use the segmenter as follows with any (possibly dummy) inputs:
```
KernelArgumentHolder dummy_inputs;
auto segmented_fusion = SegmentCandidateFinder::segment(std::move(fusion), dummy_inputs, options);
```
If `only_segment_resharding_exprs` is set to `false` (which is the case
by default), the behavior of the segmenter is unchanged.


We also provide a fairly broad test suite to validate our
implementation.

# Why 
Resharding Exprs need to be handled differently from other Exprs because
we want them to result in posting a network collective from the host.
Therefore those expressions cannot (for now) be fused into any kernel. For
this reason, we need segment boundaries before and after those Exprs.

# How
_**Remark:** For now, the segmenter is only used [at one place before
scheduling and compiling the
fusion](https://github.com/NVIDIA/Fuser/blob/1603f39bab8c1bbe12e38f2b5de53dec3b7cc373/csrc/kernel_cache.cpp#L990)._

Recall that the segmenter first creates as many segments as there are
Expr and then tries to merge the neighbour segments incrementally in an
eager manner. The method
```
bool SegmentCandidateFinder::codeGenSupportedMerge(
    SegmentedGroup* group1,
    SegmentedGroup* group2) 
```
returns whether two groups can be merged (i.e. fused into one kernel). 

With the current patch, if
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` is set to
`true`, then the usual behavior of `codeGenSupportedMerge` is bypassed
and the function returns whether one Expr among the groups is
resharding.

Because this segmentation shouldn't depend on the input data, we use a
default (i.e., empty) `KernelArgumentHolder`, from which it is invalid to
instantiate a `SchedulerRuntimeInfo runtime_info_`. For this reason, we
had to make the latter attribute optional.

# Future/other directions

Another way to achieve the same result is to manually add segment bounds
surrounding the resharding Exprs as was suggested by @wujingyue here
NVIDIA#1571

The current implementation looks a bit "hacky" and should be
integrated more properly once multidevice schedulers are implemented
and/or the segmenter is refactored.

Later, we might want to fuse communications with computations, and
communications with each other. This would require a more advanced
segmenter and scheduler, but hopefully this patch can serve as a good
basis.

# Example:
consider the fusion:
```
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  TensorView* tv0 = makeContigTensor({4});
  fusion->addInput(tv0);
  TensorView* tv1 = sum(tv0,{3});
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = sum(tv2, {2});
  fusion->addOutput(tv3);
```

Manually scheduled as follows:
```
  DeviceMesh mesh({0,1,2,3});
  for (auto tv : {tv0, tv1, tv2, tv3}) {
    tv->setDeviceMesh(mesh);
  }
  tv0->axis(0)->parallelize(ParallelType::DIDx);
  tv1->axis(0)->parallelize(ParallelType::DIDx);
```
This scheduling implies that
- `tv0` and `tv1` are fully sharded on the devices {0,1,2,3}
- `tv2` and `tv3` are fully replicated on those same devices
- consequently, the "set" operation on the line `tv2 = set(tv1)`
actually embeds an "AllGather" network collective. This Expr is
resharding while all the other exprs are not. We thus expect this
expression to constitute an unmergeable segment.

In this situation, with the option
`SegmentCandidateFinderOptions::only_segment_resharding_exprs` set
to `true`, the segmenter will produce three segments:
- Compute segment 1: with the expr `tv1 = sum(tv0,{3})`
- Communication segment 1:  with the expr `tv2 = set(tv1)`
- Compute segment 2: with the expr `tv3 = sum(tv2, {2})`
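The expected grouping can be mimicked with a toy, straight-line segmenter: each resharding expr becomes a singleton segment and consecutive non-resharding exprs share one compute segment. `ToyExpr` and `segmentByResharding` are hypothetical; the real `SegmentCandidateFinder` works on a DAG by merging neighbouring groups, which this sketch does not model:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy expr in program order.
struct ToyExpr {
  std::string text;
  bool is_resharding;
};

std::vector<std::vector<std::string>> segmentByResharding(
    const std::vector<ToyExpr>& exprs) {
  std::vector<std::vector<std::string>> segments;
  bool last_is_compute = false;
  for (const auto& e : exprs) {
    if (e.is_resharding) {
      segments.push_back({e.text});  // unmergeable communication segment
      last_is_compute = false;
    } else if (last_is_compute) {
      segments.back().push_back(e.text);  // extend the current compute segment
    } else {
      segments.push_back({e.text});  // start a new compute segment
      last_is_compute = true;
    }
  }
  return segments;
}
```

Feeding it the three exprs of the example, with only `tv2 = set(tv1)` flagged as resharding, yields the same three segments listed above.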

* Vectorization Factor patch for computeInfoC2P with Broadcast in mapped IterDomain (NVIDIA#1625)

Fixes NVIDIA#1567

This PR patches vectorization factor in
`ContiguousInnerDimensionsMapper::computeInfoC2P`.

Handling of resolved broadcast dimension should be made on mapped
consumer tensors' from_ids, instead of the root_domain order. Added a
few tests per @zasdfgbnm 's suggestion:

```
Case 0:
T2[1024, 2, 512] = T0[1024, 2, 1] + T1[1024, 2, 512]
allocation = rfactor
--> T0 has no vectorization

Case 1:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2]
allocation = rfactor
--> T0 has vectorization 2

Case 2:
T2[1024, 512, 2] = T0[1024, 1, 2] + T1[1024, 512, 2];
T3[512, 1024, 2] = transpose(T2[1024, 512, 2])
allocation = rfactor
*except T1 has stride_order {1, 2, 0}
--> T0 has vectorization 4

Case 3:
T2[512, 1024, 2] = T0[1, 1024, 2] + T1[512, 1024, 2]
T3[1024, 512, 2] = transpose(T2[512, 1024, 2])
allocation = rfactor
--> T0 has vectorization 2
```

---------

Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>

* transpose scheduler fix: reduction IterDomain on input tensors (NVIDIA#1661)

Fixes NVIDIA#1659 

Reorders reduction IterDomain so it won't interfere with
scheduling tiling from transpose scheduler.

* Convert reduction of expanded dims to squeeze (NVIDIA#1679)

See comment in arith.cpp for details.

One controversial change here is to allow squeezing expanded dimensions,
both in our IR's `SqueezeOp` and in the user-facing functions `squeeze`.
This results in actually removing those dimensions. This behavior
diverges from PyTorch, whose `squeeze` command will ignore requested
squeezes if the size is not 1 regardless of whether that dimension is
expanded. I'm happy to discuss this change and potentially take another
course, but I think we do need to be able to remove expanded axes (see
NVIDIA#1174 (comment) for
another case where I encountered this limitation).

Fixes NVIDIA#1678

* Make sure ValGraphs are created deterministically (NVIDIA#1714)

While I was working on NVIDIA#32, I sometimes saw non-deterministic results.
Hope this is the only source of non-determinism.

* Fix squeeze-related errors (NVIDIA#1717)

This fixes current failures in `pytest_ops.py -k squeeze` and some
integration failures.

This restores our previous semantics for squeeze, which **do not match
PyTorch**. Namely, if squeeze is provided a dimension that cannot be
squeezed, we will always raise an error.

* NVFUSER_DISTRIBUTED instead of USE_DISTRIBUTED (NVIDIA#1711)

* Add the missing `clang-format on` and reformat. (NVIDIA#1722)

* Print a newline before the header. (NVIDIA#1720)

* Associate each fusion cache with its local rank in distributed setting. (NVIDIA#1699)

### Problem:
Currently, automatic serialization saves a single cache regardless of
the number of devices. In a distributed setting, each process restores
its fusion cache from the same common workspace. However, this workspace
only contains the CUDA kernels for a single device. The remaining
processes must recompile the kernels for their devices.

### Solution:
A separate process is created for each device with `ddp` or `fsdp` and
each process contains a separate `FusionCache`. This PR associates each
fusion cache with its local rank in a distributed setting, allowing
automatic serialization to create a separate workspace for each device.
During deserialization, each process loads the workspace associated with
its local rank.
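A minimal sketch of the per-rank idea, assuming the launcher exports `LOCAL_RANK` (as torchrun does) and using a made-up `rank_<n>` directory scheme; neither the helper name nor the layout is nvFuser's actual implementation:

```cpp
#include <cassert>
#include <cstdlib>
#include <string>

// Hypothetical helper: derive a per-process workspace directory from the
// local rank, falling back to rank 0 when no rank is available.
// Typical call site: workspaceForRank(base, std::getenv("LOCAL_RANK")).
std::string workspaceForRank(const std::string& base_dir,
                             const char* local_rank) {
  std::string suffix =
      (local_rank != nullptr && *local_rank != '\0') ? local_rank : "0";
  return base_dir + "/rank_" + suffix;
}
```

Keeping the environment lookup at the call site makes the naming logic easy to test without touching process state.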

* Vectorized serial grid reduction (NVIDIA#1528)

This change allows us to use vectorized loads/stores in
`serialReductionStep`. The generated kernel now looks like
```c++
  NVFUSER_UPDATE_MAGIC_ZERO;
  grid_sync::blockSerializeWait<false, false, true>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  #pragma unroll
  for(nvfuser_index_t i16 = 0; i16 < 4LL; ++i16) {
    nvfuser_index_t i17;
    i17 = 32LL * i16;
    nvfuser_index_t i18;
    i18 = 4096LL * i16;
    nvfuser_index_t i19;
    i19 = i5 + i18;
    nvfuser_index_t i20;
    i20 = -i18;
    #pragma unroll
    for(nvfuser_index_t i21 = 0; i21 < 8LL; ++i21) {
      nvfuser_index_t i22;
      i22 = 512LL * (i21 + nvfuser_zero);
      Array<float, 4LL, 4> T3;
      T3.set(float(0.000000000e+00f));
      reduction::serialReductionStep</*vec_size=*/4>(
        &T3[0LL],
        &T2[(i17 + (4LL * i21))],
        0.000000000e+00f,
        &T6[(i19 + i22)],
        [](float &a, float b) { a = a + b; },
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == 0,
        index_utils::maskedOffset<false, false, true>(blockIdx, gridDim) == index_utils::maskedSize<false, false, true>(gridDim) - 1,
        true,
        true);
      if ((b7 && (i6 < (i20 - i22)))) {
        loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T1[(i19 + i22)], &T3[0LL]);
      }
    }
  }
  grid_sync::blockSerializeRelease<false, false, true>(&T5[index_utils::maskedOffset<true, true, false>(blockIdx, gridDim)]);
  NVFUSER_UPDATE_MAGIC_ZERO;
```
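A CPU sketch of what one `serialReductionStep` call appears to do for a single 4-wide vector, based on reading the kernel above rather than the runtime source: blocks take turns (enforced by `blockSerializeWait`/`blockSerializeRelease`), the first block starts from the init value, intermediate blocks read and update the work buffer (`T6`), and the last block keeps the final value in registers (`T3`) for `loadLocalToGlobal`. `serialReductionStepSketch` is a hypothetical name:

```cpp
#include <array>
#include <cassert>
#include <cstddef>

constexpr std::size_t kVec = 4;  // matches vec_size=4 in the kernel above
using Vec = std::array<float, kVec>;

// One serialized step for one vector: fold this block's partial into the
// running value held in `work`, using vectorized (whole-Vec) loads/stores.
Vec serialReductionStepSketch(const Vec& partial, Vec& work, float init,
                              bool first_step, bool last_step) {
  Vec acc{};
  for (std::size_t i = 0; i < kVec; ++i) {
    float running = first_step ? init : work[i];  // first block skips the load
    acc[i] = running + partial[i];
  }
  if (!last_step) {
    work = acc;  // hand the running value to the next block
  }
  return acc;  // the last block's return value is the final result
}
```

Simulating three blocks with partials of 1, 2, and 3 in every lane produces 6 in every lane of the last block's result.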

* removing out-dated assert on python API (NVIDIA#1724)

removing out-dated asserts in python API `define_vector`;
adding a tests verifying the behavior

* make ci green again (NVIDIA#1730)

skip failing test.

Please enable it once we patch NVIDIA#1728

* Remove unnecessary `MATCHER_P`. (NVIDIA#1729)

* Fix Issue NVIDIA#1734 (NVIDIA#1735)

Closes Issue NVIDIA#1734

* Rename `AliasType` -> `AllocationType` (NVIDIA#1732)

* Skip executing a kernel if it's empty. (NVIDIA#1723)

I could change `compileFusion` to skip compilation as well. It turned
out to be more complicated than I expected, so I took the easier route
to skip just execution, which is at least an incremental improvement.

* don't cache slice input tv (NVIDIA#1705)

If the input tv is used by slice, don't cache it.
Fix NVIDIA#1697

* Make `MmaOp::evaluate` return output of the same dtype as `MmaOp` (NVIDIA#1733)

* Turing/Ampere Mma tests without `BroadcastOp` (NVIDIA#1672)

This PR renames `matmulAtInput` into `matmulAtInput2D`, explicitly
showing that it generates 2D inputs. This PR also adds a
`matmulAtInput3DTuring`, which is used to generate the 3D fusion inputs
(for example `[M, 1, K]` and `[1, K, N]`) for matmul. The `MmaTest` for
Turing and Ampere is modified to exclude the `BroadcastOp` and use the
3D version for generating fusion inputs. This is only the initial step
for making `scheduleMatmul` schedule a fusion not containing
`BroadcastOp`, I intentionally keep it small. Other changes will be
added in followup PRs.

Fixes NVIDIA#1628

* io_alias_ const update (NVIDIA#1740)

* Add benchmarks for RoPE. (NVIDIA#1739)

This PR adds two implementations of the RoPE module and benchmarks them
for NVIDIA#1597.

`rope_with_cat_fusion` mimics the Hugging Face implementation.
`rope_without_cat_fusion` implements an idea from @nikitaved to avoid
concatenation. Even though it looks difficult for the compiler to do it
all automatically, it's still useful to keep a record of the idea.

As a side change, I made `fd.define_tensor` accept empty contiguity.

* Make nvfuser matmul benchmarks HSH instead of HSS (NVIDIA#1712)

This matches the `at::matmul` baselines.

This PR also adds a few more problem sizes, and runs each eager-mode
baseline with and without FP16 reduction allowed.

* Reduce number of `MmaTest`s (NVIDIA#1738)

This PR is stacked on top of NVIDIA#1672

Turing/Ampere mma is only TN, so it makes no sense to test other layouts
in `MmaTest`s. These tests are intended to test mma instructions;
`ldmatrix` and `ldmatrix.trans` are tested separately in other unit
tests. The same applies to `HopperRS` tests.

* Weekly Benchmarks Input Range (NVIDIA#1708)

* Rename axes= to dims= in frontend (NVIDIA#1741)

Currently we accept `axes=` for some ops like `fd.ops.sum` and `dims=`
for others like `fd.ops.squeeze`.

This is a small attempt to make the frontend arguments more consistent.
This change renames the `axis=` kwarg to `dim=` and, likewise, `axes=`
to `dims=`.

I think we're free to set our own convention, but for reference:
- PyTorch uses `dim=` in most places and accepts either a single dim or
multiple using that same argument name, where applicable.
- NumPy uses `axis=` and, like PyTorch, accepts a list where applicable.
- `jax.lax` uses `dimensions=`
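For illustration, a kwarg that follows the PyTorch convention (one name that accepts either a single dim or several) typically normalizes its argument before use. The helper below, `normalize_dims`, is hypothetical and not part of the nvFuser frontend; it is a minimal sketch assuming PyTorch-style wrapping of negative dims.

```python
from typing import List, Sequence, Union


def normalize_dims(dims: Union[int, Sequence[int]], ndim: int) -> List[int]:
    """Hypothetical helper (not an nvFuser API): accept a single dim or a
    sequence of dims, PyTorch-style, and return sorted, non-negative,
    de-duplicated dims."""
    # Wrap a single int so both call styles share one code path.
    if isinstance(dims, int):
        dims = [dims]
    result = set()
    for d in dims:
        if not -ndim <= d < ndim:
            raise IndexError(f"dim {d} out of range for tensor of rank {ndim}")
        # Map negative dims (e.g. -1 -> ndim - 1) to their positive form.
        result.add(d % ndim)
    return sorted(result)


print(normalize_dims(-1, 3))       # [2]
print(normalize_dims([0, -1], 3))  # [0, 2]
```

Using a single name for both forms keeps call sites such as `sum(x, dims=...)` and `squeeze(x, dims=...)` uniform, which is the consistency the rename aims for.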

* Avoid unused smem workspace for serial grid reductions (NVIDIA#1727)

GridReduction can be lowered to either `gridReduce` or
`serialReductionStep`. `gridReduce` requires a smem workspace in order
to use multiple threads to aggregate partial sums. However,
`serialReductionStep` does not coordinate among threads and has no use
for a workspace. This change simply disables allocating that little bit
of extra shared memory if our only grid reductions are serial, which
currently only happens in split-K GEMM.

This reduces the smem allocated in a simple test from 16896 B to 16384 B
(a saving of 512 B, leaving about 97% of the original). More importantly, this makes the computation in
`mma_utils::generateSharedMemoryEpilogueHeuristics()` more accurate.
Tests are updated to check that this computation is accurate.

The change in `kernel.cpp` is responsible for reducing actual smem usage
for split-K. The changes to `mma_utils` and `test_gpu_tensorcore.cpp`
are needed for adding testing that our expected smem usage matches the
actual usage.
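The allocation rule described above can be sketched as follows. `GridReduction` and `needs_reduction_workspace` are hypothetical names used only for illustration, not nvFuser code; the sketch assumes each grid reduction is known to lower either to `gridReduce` (needs a workspace) or to `serialReductionStep` (does not). It also checks the smem arithmetic quoted above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class GridReduction:
    # Hypothetical stand-in for a lowered grid reduction.
    name: str
    serial: bool  # True: serialReductionStep; False: gridReduce


def needs_reduction_workspace(reductions: List[GridReduction]) -> bool:
    # gridReduce uses a smem workspace so multiple threads can combine
    # partial sums; serialReductionStep does not, so the workspace is
    # needed only if at least one grid reduction is non-serial.
    return any(not r.serial for r in reductions)


# Split-K GEMM: the only grid reduction is serial, so no workspace.
split_k = [GridReduction("splitk_sum", serial=True)]
print(needs_reduction_workspace(split_k))  # False

# Arithmetic for the smem numbers quoted above.
before, after = 16896, 16384
print(before - after)                  # 512 (bytes saved)
print(round(100 * after / before, 1))  # 97.0 (percent of original)
```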

* Issue NVIDIA#1748 (NVIDIA#1749)

Closes Issue NVIDIA#1748.
Apart from `c10::cuda::GetDevice`, no other functionality seems
affected.

* Rename `axes` to `dims` in benchmarks fusion definitions (NVIDIA#1751)

Changes the kwarg `axes` to `dims` following the API change in PR NVIDIA#1741.

* Bump matmul benchmark checkMatch() tolerance (NVIDIA#1747)

This is necessary due to the recent switch to HSH.

Fixes NVIDIA#1746

* linter

* change guard USE_DISTRIBUTED to NVFUSER_DISTRIBUTED in test/test_multidevice_sharding.cpp

* linting

* linter and cleanup

* remove allocator.h/cpp files

* Device index patch (NVIDIA#1752)

Fixes NVIDIA#1748 

Guard the `c10::cuda::GetDevice` API change on `TORCH_VERSION`.

This change ensures that we can build against stable releases
`< 2.2.0`, as well as against top-of-tree (TOT) after
pytorch/pytorch#119142.

For 2.3.0 nightlies, if someone accidentally checks out a commit from
before the patch, the build will still fail.

* fixing multidevice build (NVIDIA#1753)

API change coming from pytorch/pytorch#119421

* patching API GUARD (NVIDIA#1754)

Patch the API version guard so that we can still build against older
PyTorch versions.

* Add a visitor for ValGraph (NVIDIA#1713)

Used in the loop promotion analysis. Extracted from NVIDIA#32

* empty commit for triggering CI

---------

Co-authored-by: Liqiang Lu <116412316+liqiangxl@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
Co-authored-by: Tom Fogal <60981+tfogal@users.noreply.github.com>
Co-authored-by: jjsjann123 <jiej@nvidia.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
Co-authored-by: Naoya Maruyama <naoyam@users.noreply.github.com>
Co-authored-by: Meghan Cowan <mcowan@nvidia.com>
Co-authored-by: Ryan Spring <rspring@nvidia.com>
zasdfgbnm added a commit that referenced this issue Feb 20, 2024
@zasdfgbnm
Collaborator Author

This issue was closed by #1736. I will not reopen it, because the main purpose of this issue is tracking the Hopper matmul issue, which is already fixed in #1736 by modifying the test. The segmenter issue is tracked in #1707 and is not resolved yet.

zasdfgbnm added a commit that referenced this issue Feb 21, 2024
Partially fixes: #1628

Make `scheduleMatmul` capable of handling inputs like `[M, 1, K] x [1,
N, K]`.
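To make the shape convention concrete, here is a plain-Python reference for what a mul+sum matmul computes on inputs that are already shaped `[M, 1, K]` and `[1, N, K]`. It is a sketch of the semantics only, not nvFuser code, and `matmul_from_3d` is a hypothetical name. Because the size-1 axes broadcast implicitly in the elementwise multiply, no `BroadcastOp` is needed; summing the `[M, N, K]` product over its last axis (K) yields `[M, N]`.

```python
from typing import List

Tensor3D = List[List[List[float]]]


def matmul_from_3d(a: Tensor3D, b: Tensor3D) -> List[List[float]]:
    """Reference semantics for a mul+sum matmul on pre-broadcast inputs.

    a has shape [M, 1, K] and b has shape [1, N, K]. The size-1 axes are
    broadcast by the elementwise multiply, and the [M, N, K] product is
    summed over its last axis (K) to give an [M, N] result.
    """
    M, N, K = len(a), len(b[0]), len(a[0][0])
    assert len(a[0]) == 1 and len(b) == 1 and len(b[0][0]) == K
    out = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            # a[m][0] is row m of A; B is stored N-by-K (K innermost), so
            # b[0][n] holds column n of the mathematical K-by-N operand.
            out[m][n] = sum(x * y for x, y in zip(a[m][0], b[0][n]))
    return out


# A = [[1, 2], [3, 4]] (M=2, K=2); B stored N-by-K as [[5, 6], [7, 8]].
a = [[[1.0, 2.0]], [[3.0, 4.0]]]
b = [[[5.0, 6.0], [7.0, 8.0]]]
print(matmul_from_3d(a, b))  # [[17.0, 23.0], [39.0, 53.0]]
```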

---------

Co-authored-by: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com>
zasdfgbnm added a commit that referenced this issue Feb 21, 2024
Stacked on #1776; this just migrates more tests.
Partially fixes: #1628

---------

Co-authored-by: Andrzej Bekas <118676880+drzejan2@users.noreply.github.com>
4 participants