From 5c03690433d8bb6e9a2a47f5a3070ae1cbbd6348 Mon Sep 17 00:00:00 2001 From: Ali Hassani Date: Mon, 9 Jan 2023 18:10:52 -0500 Subject: [PATCH 1/2] Adds missing semicolon --- include/cutlass/arch/mma.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 7385d882bf..2bcabb2d69 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -70,7 +70,7 @@ struct OpMultiplyAddFastF16 {}; /// Tag indicating the input is converted to 2 (big and small) TF32 components // Perform 3xTF32 or 4xTF32 for every F32 output element -struct OpMultiplyAddFastF32 {} +struct OpMultiplyAddFastF32 {}; /// Tag indicating the input is converted to 2 (big and small) TF32 components // Perform 3xTF32 or 4xTF32 for every complex output element From 6e701f5db085c073e51ad162df132051f73a4400 Mon Sep 17 00:00:00 2001 From: Ali Hassani Date: Tue, 10 Jan 2023 02:26:36 -0500 Subject: [PATCH 2/2] Partial specializations for cp_async_zfill --- include/cutlass/arch/memory_sm80.h | 78 ++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index 8d5822e5c6..b7d16da10b 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -104,6 +104,40 @@ static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fallback specialization for 1 byte copies +template < + /// Cache operation - ignored + CacheOperation::Kind cache_op> +struct cp_async<1, cache_op> { + + /// Copy + CUTLASS_DEVICE + cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + } +}; + +/// Fallback specialization for 2 byte copies +template < + /// Cache operation - ignored + CacheOperation::Kind cache_op> +struct cp_async<2, cache_op> { + + /// Copy + CUTLASS_DEVICE + cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + } +}; + /// Partial specialization template < /// Size of the access in bytes @@ -143,6 +177,50 @@ struct cp_async { } }; +/// Fallback specialization for 1 byte copies +template < + /// Cache operation - ignored + CacheOperation::Kind cache_op> +struct cp_async_zfill<1, cache_op> { + + /// Copy with zero fill + CUTLASS_DEVICE + cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + } +}; + +/// Fallback specialization for 2 byte copies +template < + /// Cache operation - ignored + CacheOperation::Kind cache_op> +struct cp_async_zfill<2, cache_op> { + + /// Copy with zero fill + CUTLASS_DEVICE + cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) { + using AccessType = Array; + + if (pred_guard) { + *static_cast(smem_ptr) = *static_cast(global_ptr); + } + else { + AccessType zeros; + zeros.clear(); + *static_cast(smem_ptr) = zeros; + } + } +}; + /// Partial specialization template < /// Size of the access in bytes