Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions include/cutlass/arch/memory_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,40 @@ static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16);

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Fallback specialization of cp_async for 1-byte accesses.
///
/// The copy is carried out as a plain assignment of an Array<uint8_t, 1>
/// through the two pointers; the cache-operation template argument is accepted
/// but never consulted.
///
/// NOTE(review): this file also appears to declare
/// cp_async<SizeInBytes, CacheOperation::Always>. An instantiation such as
/// cp_async<1, CacheOperation::Always> would match both partial
/// specializations, and neither is more specialized than the other - confirm
/// this cannot be instantiated ambiguously.
template <
    /// Cache operation - ignored
    CacheOperation::Kind cache_op>
struct cp_async<1, cache_op> {

  /// Copies one byte from global_ptr to smem_ptr when pred_guard is true.
  /// When pred_guard is false nothing is read and the destination is left
  /// unmodified.
  CUTLASS_DEVICE
  cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
    using Fragment = Array<uint8_t, 1>;

    // Predicated off: skip the access entirely.
    if (!pred_guard) {
      return;
    }

    Fragment *dst = static_cast<Fragment *>(smem_ptr);
    Fragment const *src = static_cast<Fragment const *>(global_ptr);
    *dst = *src;
  }
};

/// Fallback specialization of cp_async for 2-byte accesses.
///
/// The copy is carried out as a plain assignment of an Array<uint8_t, 2>
/// through the two pointers; the cache-operation template argument is accepted
/// but never consulted.
///
/// NOTE(review): this file also appears to declare
/// cp_async<SizeInBytes, CacheOperation::Always>. An instantiation such as
/// cp_async<2, CacheOperation::Always> would match both partial
/// specializations, and neither is more specialized than the other - confirm
/// this cannot be instantiated ambiguously.
template <
    /// Cache operation - ignored
    CacheOperation::Kind cache_op>
struct cp_async<2, cache_op> {

  /// Copies two bytes from global_ptr to smem_ptr when pred_guard is true.
  /// When pred_guard is false nothing is read and the destination is left
  /// unmodified.
  CUTLASS_DEVICE
  cp_async(void *smem_ptr, void const *global_ptr, bool pred_guard = true) {
    using Fragment = Array<uint8_t, 2>;

    // Predicated off: skip the access entirely.
    if (!pred_guard) {
      return;
    }

    Fragment *dst = static_cast<Fragment *>(smem_ptr);
    Fragment const *src = static_cast<Fragment const *>(global_ptr);
    *dst = *src;
  }
};

/// Partial specialization
template <
/// Size of the access in bytes
Expand Down Expand Up @@ -143,6 +177,50 @@ struct cp_async<SizeInBytes, CacheOperation::Always> {
}
};

/// Fallback specialization of cp_async_zfill for 1-byte accesses.
///
/// Implemented with plain assignments of an Array<uint8_t, 1>; the
/// cache-operation template argument is accepted but never consulted.
template <
    /// Cache operation - ignored
    CacheOperation::Kind cache_op>
struct cp_async_zfill<1, cache_op> {

  /// Copy with zero fill: stores the byte read from global_ptr when
  /// pred_guard is true, otherwise stores a cleared (zeroed) value without
  /// reading global_ptr.
  CUTLASS_DEVICE
  cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
    using Fragment = Array<uint8_t, 1>;

    Fragment *dst = static_cast<Fragment *>(smem_ptr);

    if (pred_guard) {
      *dst = *static_cast<Fragment const *>(global_ptr);
      return;
    }

    // Guard is off: the destination is still written, but with zeros.
    Fragment fill;
    fill.clear();
    *dst = fill;
  }
};

/// Fallback specialization of cp_async_zfill for 2-byte accesses.
///
/// Implemented with plain assignments of an Array<uint8_t, 2>; the
/// cache-operation template argument is accepted but never consulted.
template <
    /// Cache operation - ignored
    CacheOperation::Kind cache_op>
struct cp_async_zfill<2, cache_op> {

  /// Copy with zero fill: stores the two bytes read from global_ptr when
  /// pred_guard is true, otherwise stores a cleared (zeroed) value without
  /// reading global_ptr.
  CUTLASS_DEVICE
  cp_async_zfill(void *smem_ptr, void const *global_ptr, bool pred_guard) {
    using Fragment = Array<uint8_t, 2>;

    Fragment *dst = static_cast<Fragment *>(smem_ptr);

    if (pred_guard) {
      *dst = *static_cast<Fragment const *>(global_ptr);
      return;
    }

    // Guard is off: the destination is still written, but with zeros.
    Fragment fill;
    fill.clear();
    *dst = fill;
  }
};

/// Partial specialization
template <
/// Size of the access in bytes
Expand Down
2 changes: 1 addition & 1 deletion include/cutlass/arch/mma.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ struct OpMultiplyAddFastF16 {};

/// Tag indicating the input is converted to 2 (big and small) TF32 components
//  Perform 3xTF32 or 4xTF32 for every F32 output element
//
//  Empty tag type: it carries no data and exists only to be named as a type.
//  The definition must end with ';' and appear exactly once.
struct OpMultiplyAddFastF32 {};

/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every complex<F32> output element
Expand Down