[QST] [CuTeDSL] Out of bound access even with Predication #2980

@HanGuo97

What is your question?

Hi, I have two questions related to predication and out-of-bounds accesses.

In the following code, thread 192 (the only thread that issues the copy, via the if tidx == 192 guard) is predicated out (pred = 0 in the printout), yet the copy still leads to an out-of-bounds access. Is this expected behavior?

import time
import cutlass.cute as cute
import cutlass
import torch

# Initialize CUDA context
torch.rand(1, device="cuda")


@cute.kernel
def kernel(gmem: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    allocator = cutlass.utils.SmemAllocator()

    # Allocate shared memory for 192 float16 entries
    smem = allocator.allocate_tensor(
        element_type=cute.Float16,
        layout=cute.make_layout((192,)),
        byte_alignment=16,
        swizzle=None,
    )

    # Create copy atom for 32-bit async copy (2 float16 elements per copy)
    copy_atom = cute.make_copy_atom(
        op=cute.nvgpu.cpasync.CopyG2SOp(),
        copy_internal_type=cute.Float16,
        num_bits_per_copy=32,
    )

    # Create tiled copy: 384 threads, 2 elements per copy
    tiler_mn, layout_tv = cute.make_layout_tv(
        thr_layout=cute.make_layout(384),
        val_layout=cute.make_layout(2),
    )
    tiled_copy = cute.make_tiled_copy(
        atom=copy_atom,
        layout_tv=layout_tv,
        tiler_mn=tiler_mn,
    )
    thread_copy = tiled_copy.get_slice(thr_idx=tidx)

    # Partition source, destination, and coordinates
    crd = cute.make_identity_tensor(gmem.shape)
    src_thread = thread_copy.partition_S(gmem)
    dst_thread = thread_copy.partition_D(smem)
    crd_thread = thread_copy.partition_S(crd)

    # Create 1D predicate tensor (checked at copy atom granularity)
    pred_thread = cute.make_rmem_tensor(
        layout_or_shape=cute.make_layout(
            (crd_thread.shape[0][1], cute.size(crd_thread, mode=[1]))
        ),
        dtype=cutlass.Boolean,
    )
    for rest_v in cutlass.range(pred_thread.shape[0], unroll_full=True):
        for i in cutlass.range(pred_thread.shape[1], unroll_full=True):
            pred_thread[rest_v, i] = cute.elem_less(
                crd_thread[(0, rest_v), i][0],
                gmem.shape[0],
            )

    # Async copy from global to shared with predication
    if tidx == 192:
        cute.printf("[tid={}] pred_thread={}", tidx, pred_thread)
        cute.copy(
            atom=copy_atom,
            src=src_thread,
            dst=dst_thread,
            pred=pred_thread,
        )

    # Wait for async copy and synchronize
    cute.arch.cp_async_commit_group()
    cute.arch.cp_async_wait_group(0)
    cute.arch.barrier()

    # Print results from first and last few threads
    # if tidx < 10 or (tidx >= 190 and tidx < 192):
    #     val = smem[tidx] if tidx < 192 else cute.Float16(-1)
    #     cute.printf("[tid={}] smem[{}]={}", tidx, tidx, val)

@cute.jit
def launch_kernel(gmem: cute.Tensor):
    k = kernel(gmem)
    k.launch(grid=(1, 1, 1), block=(384, 1, 1))


# Create global memory tensor and launch
gmem_tensor = torch.arange(192, dtype=torch.float16, device="cuda")
gmem = cute.runtime.from_dlpack(gmem_tensor, assumed_align=16)

fn = cute.compile(launch_kernel, gmem, options="--generate-line-info")
fn(gmem)
time.sleep(1)
fn(gmem)
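
For reference, the snippet below is a plain-Python sketch (not CuTeDSL) of the thread-to-element bookkeeping, under the assumption that each thread copies 2 contiguous fp16 elements so that a single 32-bit cp.async per thread is possible. Under that assumption, thread 192 maps to elements 384 and 385 of the 192-element tensor, which matches the pred = 0 printout.

# Plain Python only -- rough bookkeeping, not CuTeDSL.
# Assumes thread t copies the 2 contiguous elements 2*t and 2*t + 1,
# as required for one 32-bit cp.async per thread.
vals_per_thread, extent = 2, 192

for t in (0, 95, 96, 192):
    elems = [t * vals_per_thread + v for v in range(vals_per_thread)]
    in_bounds = [e < extent for e in elems]
    print(f"thread {t}: elements {elems}, in_bounds {in_bounds}")
# thread 0: elements [0, 1], in_bounds [True, True]
# thread 95: elements [190, 191], in_bounds [True, True]
# thread 96: elements [192, 193], in_bounds [False, False]
# thread 192: elements [384, 385], in_bounds [False, False]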

In the following kernel, the way the if syntax is used in Python also seems to change whether an out-of-bounds access occurs.

import time
import cutlass.cute as cute
import cutlass
import torch

# Initialize CUDA context
torch.rand(1, device="cuda")


@cute.kernel
def kernel(gmem: cute.Tensor):
    tidx, _, _ = cute.arch.thread_idx()
    allocator = cutlass.utils.SmemAllocator()

    # Allocate shared memory for 192 float16 entries
    smem = allocator.allocate_tensor(
        element_type=cute.Float16,
        layout=cute.make_layout((192,)),
        byte_alignment=16,
        swizzle=None,
    )

    # This version leads to an out-of-bounds error:
    # if tidx >= 190 and tidx <= 193:
    #     val = smem[tidx] if tidx < 192 else cute.Float16(-1)
    #     cute.printf("[tid={}] smem[{}]={}", tidx, tidx, val)

    # This version does not: the shared-memory read is guarded by its own if
    if tidx >= 190 and tidx <= 193:
        val = cute.Float16(-1)
        if tidx < 192:
            val = smem[tidx]
        cute.printf("[tid={}] smem[{}]={}", tidx, tidx, val)

@cute.jit
def launch_kernel(gmem: cute.Tensor):
    k = kernel(gmem)
    k.launch(grid=(1, 1, 1), block=(384, 1, 1))


# Create global memory tensor and launch
gmem_tensor = torch.arange(192, dtype=torch.float16, device="cuda")
gmem = cute.runtime.from_dlpack(gmem_tensor, assumed_align=16)

fn = cute.compile(launch_kernel, gmem, options="--generate-line-info")
fn(gmem)
time.sleep(1)
fn(gmem)
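
As a plain-Python analogy (a hypothetical illustration only, not a statement about CuTeDSL's actual lowering), the difference resembles a conditional that is turned into a select whose operands are both materialized before the condition is applied, versus a real branch that never touches the out-of-range load:

# Plain-Python analogy -- hypothetical, not CuTeDSL semantics.
# select() stands in for a lowering that evaluates both operands up front.
data = list(range(192))

def select(cond, a, b):
    # Both a and b are already evaluated by the time we get here.
    return a if cond else b

def read_ternary_style(i):
    # data[i] is evaluated even when the condition is False,
    # so any i >= 192 raises IndexError here.
    return select(i < 192, data[i], -1.0)

def read_if_style(i):
    val = -1.0
    if i < 192:        # the access itself is guarded
        val = data[i]
    return val

print(read_if_style(193))      # prints -1.0
# read_ternary_style(193)      # would raise IndexError: list index out of range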
