-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Open
Labels
Description
What is your question?
Hi, I have two questions related to predication and out-of-bound access.
In the following code, thread 300 is predicated out (pred = 0 in the printout), yet it still leads to an out-of-bounds access. Is this expected behavior?
import time
import cutlass.cute as cute
import cutlass
import torch
# Initialize CUDA context
torch.rand(1, device="cuda")
@cute.kernel
def kernel(gmem: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
allocator = cutlass.utils.SmemAllocator()
# Allocate shared memory for 192 float16 entries
smem = allocator.allocate_tensor(
element_type=cute.Float16,
layout=cute.make_layout((192,)),
byte_alignment=16,
swizzle=None,
)
# Create copy atom for 32-bit async copy (2 float16 elements per copy)
copy_atom = cute.make_copy_atom(
op=cute.nvgpu.cpasync.CopyG2SOp(),
copy_internal_type=cute.Float16,
num_bits_per_copy=32,
)
# Create tiled copy: 384 threads, 2 elements per copy
tiler_mn, layout_tv = cute.make_layout_tv(
thr_layout=cute.make_layout(384),
val_layout=cute.make_layout(2),
)
tiled_copy = cute.make_tiled_copy(
atom=copy_atom,
layout_tv=layout_tv,
tiler_mn=tiler_mn,
)
thread_copy = tiled_copy.get_slice(thr_idx=tidx)
# Partition source, destination, and coordinates
crd = cute.make_identity_tensor(gmem.shape)
src_thread = thread_copy.partition_S(gmem)
dst_thread = thread_copy.partition_D(smem)
crd_thread = thread_copy.partition_S(crd)
# Create 1D predicate tensor (checked at copy atom granularity)
pred_thread = cute.make_rmem_tensor(
layout_or_shape=cute.make_layout(
(crd_thread.shape[0][1], cute.size(crd_thread, mode=[1]))
),
dtype=cutlass.Boolean,
)
for rest_v in cutlass.range(pred_thread.shape[0], unroll_full=True):
for i in cutlass.range(pred_thread.shape[1], unroll_full=True):
pred_thread[rest_v, i] = cute.elem_less(
crd_thread[(0, rest_v), i][0],
gmem.shape[0],
)
# Async copy from global to shared with predication
if tidx == 192:
cute.printf("[tid={}] pred_thread={}", tidx, pred_thread)
cute.copy(
atom=copy_atom,
src=src_thread,
dst=dst_thread,
pred=pred_thread,
)
# Wait for async copy and synchronize
cute.arch.cp_async_commit_group()
cute.arch.cp_async_wait_group(0)
cute.arch.barrier()
# Print results from first and last few threads
# if tidx < 10 or (tidx >= 190 and tidx < 192):
# val = smem[tidx] if tidx < 192 else cute.Float16(-1)
# cute.printf("[tid={}] pred={} smem[{}]={}", tidx, pred, tidx, val)
@cute.jit
def launch_kernel(gmem: cute.Tensor):
k = kernel(gmem)
k.launch(grid=(1, 1, 1), block=(384, 1, 1))
# Create global memory tensor and launch
gmem_tensor = torch.arange(192, dtype=torch.float16, device="cuda")
gmem = cute.runtime.from_dlpack(gmem_tensor, assumed_align=16)
fn = cute.compile(launch_kernel, gmem, options="--generate-line-info")
fn(gmem)
time.sleep(1)
fn(gmem)In the following function, seems like just how the if syntax is used in python can also lead to different behaviors regarding out of bound accesses.
import time
import cutlass.cute as cute
import cutlass
import torch
# Initialize CUDA context
torch.rand(1, device="cuda")
@cute.kernel
def kernel(gmem: cute.Tensor):
tidx, _, _ = cute.arch.thread_idx()
allocator = cutlass.utils.SmemAllocator()
# Allocate shared memory for 192 float16 entries
smem = allocator.allocate_tensor(
element_type=cute.Float16,
layout=cute.make_layout((192,)),
byte_alignment=16,
swizzle=None,
)
# this will lead to error
# if tidx >= 190 and tidx <= 193:
# val = smem[tidx] if tidx < 192 else cute.Float16(-1)
# cute.printf("[tid={}] pred={} smem[{}]={}", tidx, pred, tidx, val)
if tidx >= 190 and tidx <= 193:
val = cute.Float16(-1)
if tidx < 192:
val = smem[tidx]
cute.printf("[tid={}] pred={} smem[{}]={}", tidx, pred, tidx, val)
@cute.jit
def launch_kernel(gmem: cute.Tensor):
k = kernel(gmem)
k.launch(grid=(1, 1, 1), block=(384, 1, 1))
# Create global memory tensor and launch
gmem_tensor = torch.arange(192, dtype=torch.float16, device="cuda")
gmem = cute.runtime.from_dlpack(gmem_tensor, assumed_align=16)
fn = cute.compile(launch_kernel, gmem, options="--generate-line-info")
fn(gmem)
time.sleep(1)
fn(gmem)