CuTe DSL: work around SM120 rank-2 TMA cute.copy hang #3189

Closed
alecco wants to merge 3 commits into NVIDIA:main from alecco:sm120-direct-tma-wrapper-pr

Conversation

alecco commented Apr 28, 2026

Author: Alecco (& Codex) for Ologan

Summary

On SM120/SM120a, a CuTe DSL rank-2 G2S descriptor TMA load issued through the normal executable cute.copy path compiles and launches, but then hangs while waiting on the TMA pipeline mbarrier.

A minimal failing shape is a CTA-local rank-2 TMA load over a direct-basis (d, seq) FP32 tensor:

cute.copy(
    tma_atom,
    tma_g,
    tma_s,
    tma_bar_ptr=pipe.producer_get_barrier(producer_state),
)

The kernel reaches launch, but torch.cuda.synchronize() does not return. A repro is included at the bottom of this PR.

This PR does not fix the underlying _cute_nvgpu executable TMA lowering. Instead, it adds a narrow SM120 direct TMA issue path that uses the same CuTe-generated descriptor but bypasses the failing cute.copy lowering. It also documents and validates the contract for using that path with swizzled shared-memory destinations.

Problem

The reduced diagnosis was:

Direct-basis CuTe descriptor + raw SM120 TMA issue + PipelineTmaAsync: pass
Driver API descriptor + raw SM120 TMA issue + PipelineTmaAsync: pass
Direct-basis CuTe descriptor + normal cute.copy/PipelineTmaAsync: timeout
Driver API descriptor + normal cute.copy/PipelineTmaAsync: timeout
FA-style canonicalize/logicalize descriptor + raw issue: coord1 ignored

This points to two separate issues:

  1. normal executable rank-2 TMA cute.copy lowering is not currently safe for this SM120 path;
  2. the earlier canonicalize-then-logicalize approach for FlashAttention-style (seq, d) tensors can construct a descriptor path where the second coordinate is not represented correctly.

The configuration that works builds the descriptor in the direct TMA basis from the start:

logical GMEM:      (seq, d)
TMA descriptor:    (d, seq)
TMA strides:       (1, D)
TMA coordinates:   {d_coord, seq_coord}
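
Both layouts address the same memory: for a row-major (seq, d) tensor with leading dimension D, logical element (s, k) and descriptor coordinate {k, s} resolve to the same offset. As a worked check:

logical (seq, d), strides (D, 1):     offset(s, k) = s * D + k
descriptor (d, seq), strides (1, D):  offset(k, s) = k * 1 + s * D

Only the coordinate order changes; no data moves.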

For swizzled shared memory, the same direct-basis rule applies. The TMA atom must be built with the intended composed SMEM layout, and the destination pointer passed to the direct issue helper must carry the matching swizzle.

Workaround Added

This PR adds an explicit SM120/SM120a rank-2 TMA issue helper under:

cutlass.cute.nvgpu.cpasync

The intended path is:

gmem_tma = cute.make_tensor(
    gmem.iterator,
    cute.make_layout((D, S), stride=(1, D)),
)

tma_atom, tma_tensor, desc_ptr = cpasync.make_sm120_tma_load_2d_atom(
    gmem_tma,
    smem_layout_tma_basis,
    cta_tiler_tma_basis,
)

cpasync.sm120_tma_load_2d(
    dst_smem_ptr,
    desc_ptr,
    pipe.producer_get_barrier(producer_state),
    d_coord,
    seq_coord,
)

This keeps descriptor construction in CuTe DSL, but avoids the failing executable cute.copy TMA issue path.

For swizzled SMEM, callers should use a matching layout and pointer contract. Conceptually:

swizzle = cute.make_swizzle(3, 4, 3)

smem_layout_tma_basis = cute.make_composed_layout(
    swizzle,
    0,
    base_smem_layout_tma_basis,
)

sT = smem.allocate_tensor(
    dtype,
    base_smem_layout_tma_basis,
    byte_alignment=128,
    swizzle=swizzle,
)

tma_atom, tma_tensor, desc_ptr = cpasync.make_sm120_tma_load_2d_atom(
    gmem_tma,
    smem_layout_tma_basis,
    cta_tiler_tma_basis,
)

cpasync.sm120_tma_load_2d(
    sT.iterator,
    desc_ptr,
    pipe.producer_get_barrier(producer_state),
    d_coord,
    seq_coord,
)

The key requirement is that the descriptor’s SMEM layout and the destination SMEM pointer agree about the swizzle. The smoke test reads back through the same swizzled CuTe tensor view.
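
A minimal consumer-side sketch of that readback, modeled on the repro's consumer warp at the bottom of this PR; out and the index mapping here are illustrative only:

pipe.consumer_wait(consumer_state)
cute.arch.fence_view_async_shared()
# Index through sT, the same swizzled view the descriptor was built
# against; the composed layout resolves the swizzle on access.
out[lane] = sT[lane, 0]
pipe.consumer_release(consumer_state)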

New API

cpasync.get_tma_desc_addr(tma_atom)

Returns the tiled TMA descriptor address associated with a TMA copy atom.
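
A minimal usage sketch, reusing names from the examples above, for atoms built through paths that do not return the descriptor address directly:

desc_ptr = cpasync.get_tma_desc_addr(tma_atom)
cpasync.sm120_tma_load_2d(
    dst_smem_ptr,
    desc_ptr,
    pipe.producer_get_barrier(producer_state),
    d_coord,
    seq_coord,
)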

cpasync.sm120_tma_load_2d(
    dst_smem_ptr,
    tma_desc_ptr,
    tma_bar_ptr,
    coord0,
    coord1,
    *,
    cache_policy=None,
    already_elected=False,
    tile_mode=False,
)

Issues a CTA-local SM120 rank-2 TMA load. By default it performs warp election internally, so callers do not accidentally issue multiple TMA transactions for one mbarrier phase.
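
When the caller has already elected a single issuing lane, for example inside cute.arch.elect_one() as in the repro below, the internal election can be skipped. A sketch, reusing names from the examples above:

with cute.arch.elect_one():
    cpasync.sm120_tma_load_2d(
        dst_smem_ptr,
        desc_ptr,
        pipe.producer_get_barrier(producer_state),
        d_coord,
        seq_coord,
        already_elected=True,  # caller guarantees exactly one issuing thread
    )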

cpasync.make_sm120_tma_load_2d_atom(
    gmem_tensor,
    smem_layout,
    cta_tiler,
    *,
    internal_type=None,
)

Builds a narrow direct-basis rank-2 SM120 TMA load atom and returns:

(tma_atom, tma_tensor, desc_ptr)

This helper deliberately does not canonicalize or logicalize modes. Callers must pass a direct-basis tensor where mode 0 is physically contiguous.
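
A hedged illustration of that contract, for a row-major (seq, d) source with leading dimension D (gmem as in the earlier examples):

# Accepted: direct basis, mode 0 physically contiguous (stride 1).
gmem_tma = cute.make_tensor(
    gmem.iterator,
    cute.make_layout((D, S), stride=(1, D)),
)

# Not accepted by this helper: mode 0 strided. Reorder to the direct
# basis first instead of passing cute.make_layout((S, D), stride=(D, 1)).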

Validation

Added a standalone SM120 smoke script:

CUTE_DSL_ARCH=sm_120 python test/examples/CuTeDSL/sm_120a/sm120_direct_tma_smoke.py

It validates:

CuTe-generated direct-basis rank-2 descriptors
cpasync.get_tma_desc_addr(...)
cpasync.sm120_tma_load_2d(...)
PipelineTmaAsync
FP32 and BF16 coordinate sweeps with nonzero {d, seq} coordinates
.tile + .L2::cache_hint instruction spelling
FA-like K/V direct-basis loads for FP16/BF16 with D = 64, 96, 128 and seq_tile = 64, 128
SW128 swizzled SMEM load/readback for FP16/BF16 64x64 tiles

Local result:

SM120 direct TMA smoke passed

Also checked syntax with:

python -m py_compile \
  python/CuTeDSL/cutlass/cute/nvgpu/cpasync/helpers.py \
  python/CuTeDSL/cutlass/cute/nvgpu/cpasync/__init__.py \
  test/examples/CuTeDSL/sm_120a/sm120_direct_tma_smoke.py

The swizzled coverage is intentionally conservative. The local SM120 setup has a 99 KiB shared-memory limit, so the committed smoke keeps individual tested tiles below a 64 KiB budget:

FP32 64x64               = 16 KiB
BF16/FP16 64x64          =  8 KiB
BF16/FP16 96x128         = 24 KiB
BF16/FP16 128x128        = 32 KiB
BF16/FP16 64x64 swizzled =  8 KiB
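
The budget arithmetic is just rows x cols x dtype bytes; a standalone check of the table above:

# Verifies the tile-size table: bytes = rows * cols * bytes-per-element.
for name, rows, cols, elem_bytes in [
    ("FP32 64x64", 64, 64, 4),
    ("BF16/FP16 64x64", 64, 64, 2),
    ("BF16/FP16 96x128", 96, 128, 2),
    ("BF16/FP16 128x128", 128, 128, 2),
    ("BF16/FP16 64x64 swizzled", 64, 64, 2),
]:
    print(f"{name:<24} = {rows * cols * elem_bytes // 1024:>2} KiB")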

Larger swizzled shapes are not claimed by this PR and should be validated separately.

Scope

This is intentionally narrow:

supported: SM120/SM120a CTA-local rank-2 G2S TMA
supported: direct-basis descriptors where mode 0 is physically contiguous
supported: optional swizzled SMEM when descriptor layout and destination pointer swizzle match
not added: multicast
not added: CTA group 2
not added: arbitrary ND TMA
not added: automatic (seq, d) -> (d, seq) canonicalize/logicalize routing
not changed: generic cute.copy TMA lowering

The underlying _cute_nvgpu executable TMA lowering should still be fixed separately. This PR gives CuTe DSL users a tested SM120 rank-2 path in the meantime, including the SMEM swizzle contract needed for optimized shared-memory staging.

Agent disclosure

This work was created with the OpenAI Codex CLI agent, but a human directed and supervised it and reviewed every line.

Issue repro

#!/usr/bin/env python3
"""Minimal SM120 rank-2 TMA `cute.copy` timeout repro.

Run from the CUTLASS repo root:

    CUTE_DSL_ARCH=sm_120 timeout 90s python agent_space/sm120_cute_copy_tma_timeout_repro.py

Expected affected behavior:

    launching SM120 rank-2 TMA cute.copy repro; expected failure is timeout...
    # process hangs in torch.cuda.synchronize(), then timeout exits with status 124
"""

import os

import torch

import cutlass
import cutlass.cute as cute
import cutlass.pipeline
import cutlass.utils
from cutlass.cute.nvgpu import cpasync
from cutlass.cute.runtime import from_dlpack
from cutlass.pipeline.sm90 import PipelineTmaAsync, make_pipeline_state


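# Rank-1 control: this cute.copy TMA path completes on SM120.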
@cute.kernel
def _copy1d_kernel(tma_atom: cute.CopyAtom, tma_tensor: cute.Tensor, out: cute.Tensor):
    smem = cutlass.utils.SmemAllocator()
    smem_tile = smem.allocate_tensor(cutlass.Float32, cute.make_layout(64))
    mbar = smem.allocate_tensor(cutlass.Int64, cute.make_layout(2), byte_alignment=8)

    tidx, _, _ = cute.arch.thread_idx()
    warp = tidx // 32
    lane = tidx % 32

    pipe = PipelineTmaAsync.create(
        barrier_storage=mbar.iterator,
        num_stages=1,
        producer_group=cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, 1
        ),
        consumer_group=cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, 1
        ),
        tx_count=64 * cutlass.Float32.width // 8,
        defer_sync=False,
    )

    g_tile = cute.zipped_divide(tma_tensor, (64,))[(None,), 1]
    g_tile = cute.group_modes(g_tile, 0, cute.rank(g_tile))
    tma_s, tma_g = cpasync.tma_partition(
        tma_atom,
        cutlass.Int32(0),
        cute.make_layout(1),
        cute.group_modes(smem_tile, 0, cute.rank(smem_tile)),
        g_tile,
    )

    if warp == 0:
        with cute.arch.elect_one():
            cpasync.prefetch_descriptor(tma_atom)

        producer_state = make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, 1
        )
        pipe.producer_acquire(producer_state)
        with cute.arch.elect_one():
            cute.copy(
                tma_atom,
                tma_g,
                tma_s,
                tma_bar_ptr=pipe.producer_get_barrier(producer_state),
            )
        producer_state.advance()
        pipe.producer_tail(producer_state)

    if warp == 1:
        consumer_state = make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, 1
        )
        pipe.consumer_wait(consumer_state)
        cute.arch.fence_view_async_shared()
        pipe.consumer_release(consumer_state)
        out[lane] = smem_tile[lane]


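# Rank-2 repro: this cute.copy TMA path hangs on SM120.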
@cute.kernel
def _copy2d_kernel(tma_atom: cute.CopyAtom, tma_tensor: cute.Tensor):
    smem = cutlass.utils.SmemAllocator()
    smem_tile = smem.allocate_tensor(cutlass.Float32, cute.make_layout((64, 64)))
    mbar = smem.allocate_tensor(cutlass.Int64, cute.make_layout(2), byte_alignment=8)

    tidx, _, _ = cute.arch.thread_idx()
    warp = tidx // 32

    pipe = PipelineTmaAsync.create(
        barrier_storage=mbar.iterator,
        num_stages=1,
        producer_group=cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, 1
        ),
        consumer_group=cutlass.pipeline.CooperativeGroup(
            cutlass.pipeline.Agent.Thread, 1
        ),
        tx_count=64 * 64 * cutlass.Float32.width // 8,
        defer_sync=False,
    )

    g_tile = cute.local_tile(tma_tensor, (64, 64), (0, 1))
    tma_s, tma_g = cpasync.tma_partition(
        tma_atom,
        cutlass.Int32(0),
        cute.make_layout(1),
        cute.group_modes(smem_tile, 0, 2),
        cute.group_modes(g_tile, 0, 2),
    )

    if warp == 0:
        with cute.arch.elect_one():
            cpasync.prefetch_descriptor(tma_atom)

        producer_state = make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Producer, 1
        )
        pipe.producer_acquire(producer_state)
        with cute.arch.elect_one():
            cute.copy(
                tma_atom,
                tma_g,
                tma_s,
                tma_bar_ptr=pipe.producer_get_barrier(producer_state),
            )
        producer_state.advance()
        pipe.producer_tail(producer_state)

    if warp == 1:
        consumer_state = make_pipeline_state(
            cutlass.pipeline.PipelineUserType.Consumer, 1
        )
        pipe.consumer_wait(consumer_state)
        cute.arch.fence_view_async_shared()
        pipe.consumer_release(consumer_state)


@cute.jit
def _launch_1d(src: cute.Tensor, out: cute.Tensor):
    tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(),
        src,
        cute.make_layout((64,)),
        (64,),
        1,
    )
    _copy1d_kernel(tma_atom, tma_tensor, out).launch(
        grid=[1, 1, 1],
        block=[64, 1, 1],
        smem=64 * cutlass.Float32.width // 8 + 16,
    )


@cute.jit
def _launch_2d(src: cute.Tensor):
    direct_src = cute.make_tensor(
        src.iterator,
        cute.make_layout((128, 128), stride=(1, 128)),
    )
    tma_atom, tma_tensor = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileG2SOp(),
        direct_src,
        cute.make_layout((64, 64)),
        (64, 64),
        1,
    )
    _copy2d_kernel(tma_atom, tma_tensor).launch(
        grid=[1, 1, 1],
        block=[64, 1, 1],
        smem=64 * 64 * cutlass.Float32.width // 8 + 16,
    )


def main():
    os.environ.setdefault("CUTE_DSL_ARCH", "sm_120")
    assert torch.cuda.is_available(), "CUDA is required"
    major, minor = torch.cuda.get_device_capability()
    assert major >= 12, f"SM120/SM120a required, got SM{major}{minor}"

    src_1d = torch.arange(128, device="cuda", dtype=torch.float32)
    out_1d = torch.empty(32, device="cuda", dtype=torch.float32)
    args_1d = (
        from_dlpack(src_1d, assumed_align=16),
        from_dlpack(out_1d, assumed_align=16),
    )

    print("compiling rank-1 TMA cute.copy control...", flush=True)
    cute.compile(_launch_1d, *args_1d)
    print("rank-1 control compiled", flush=True)

    src_2d = torch.arange(128 * 128, device="cuda", dtype=torch.float32).reshape(
        128, 128
    )
    args_2d = (from_dlpack(src_2d, assumed_align=16),)

    print(
        "launching SM120 rank-2 TMA cute.copy repro; expected failure is timeout...",
        flush=True,
    )
    cute.compile(_launch_2d, *args_2d)(*args_2d)

    # Affected behavior: hangs here until the shell-level `timeout` kills it.
    torch.cuda.synchronize()
    print("unexpectedly completed")


if __name__ == "__main__":
    main()

The SM120 direct TMA load helper forwards composed SMEM layouts to CuTe's TMA atom builder. Document the required contract for that path: when the atom is built with a swizzled SMEM layout, the destination SMEM pointer passed to sm120_tma_load_2d must carry the same swizzle.

Extend the SM120 direct TMA smoke to exercise this contract with SW128 swizzled shared memory. The smoke constructs the descriptor with a composed SW128 layout, allocates the destination tensor with the matching swizzled pointer, issues the direct rank-2 TMA load, and reads back through the same swizzled CuTe tensor view.

The default smoke keeps swizzled coverage to 64x64 FP16/BF16 tiles so the diagnostic remains well below RTX 50 / SM120's 99 KiB shared-memory limit and avoids the larger swizzled shapes that still need separate backend investigation.
alecco pushed a commit to alecco/quack that referenced this pull request Apr 28, 2026
Add narrow QuACK wrappers around the SM120 rank-2 direct TMA workaround introduced in NVIDIA/cutlass#3189 ("CuTe DSL: work around SM120 rank-2 TMA cute.copy hang").

The local CuTe DSL workaround keeps CuTe descriptor construction, but bypasses the currently problematic SM120 rank-2 cute.copy issue path by exposing a direct CTA-local TMA load helper. This commit adds QuACK-side helpers for that path:

- feature checks for the required CuTe DSL cpasync helpers

- explicit row-major [seq, d] -> TMA-basis (d, seq) tensor construction

- direct rank-2 TMA atom construction

- descriptor address access

- CTA-local direct TMA load issue

This is intentionally not wired into GemmSm120 yet. GemmSm120 is under active development, and this keeps the new functionality opt-in and low-conflict. The helper is meant as reusable scaffolding for future SM120 FlashAttention-style K/V load experiments where direct TMA can stage large dense tiles while other warps do useful work.

Add copy-focused validation plus a benchmark for comparing direct TMA against two non-TMA baselines: a simple producer-warp cp.async path and a cooperative blocking copy path. The benchmark has an FA-like overlap scenario with two consumer models:

- mma: default synthetic BF16/FP16 Tensor Core work using SM120 warp-level MmaF16BF16Op, so staged K/V-like tiles feed ldmatrix and cute.gemm work

- scalar: diagnostic shared-memory read and FP32 accumulation work for isolating staging overhead

The benchmark is not intended to represent full GEMM or full FlashAttention performance. It is a focused tool for checking whether the direct TMA workaround is usable and for studying where TMA becomes worthwhile: generally larger tile transfers, enough independent consumer work to hide pipeline wait, and enough care around shared-memory footprint. The docs include Nsight Compute commands and note that workstation timing noise should be interpreted with sweeps and counters rather than single rows.

Current GEMM behavior is unchanged.

alecco commented Apr 29, 2026

Closing PR.

I was not grouping (tile_M, tile_K).

Sorry for the noise, and thanks for your patience.
