Skip to content

Commit

Permalink
additional support for float8_e4m3fnuz and _e5m2fnuz (pytorch#115214)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdaily authored and pytorchmergebot committed Jan 22, 2024
1 parent 56ef5af commit 01abb5a
Show file tree
Hide file tree
Showing 43 changed files with 708 additions and 625 deletions.
8 changes: 7 additions & 1 deletion aten/src/ATen/AccumulateType.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#include <c10/core/ScalarType.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>

// Defines the accumulation type for a scalar type.
Expand Down Expand Up @@ -87,6 +89,8 @@ MPS_ACC_TYPE(BFloat16, float);
MPS_ACC_TYPE(Half, float);
MPS_ACC_TYPE(Float8_e5m2, float);
MPS_ACC_TYPE(Float8_e4m3fn, float);
MPS_ACC_TYPE(Float8_e5m2fnuz, float);
MPS_ACC_TYPE(Float8_e4m3fnuz, float);
MPS_ACC_TYPE(float, float);
MPS_ACC_TYPE(double, float);
MPS_ACC_TYPE(int8_t, int64_t);
Expand All @@ -107,6 +111,8 @@ CUDA_ACC_TYPE(BFloat16, float);
CUDA_ACC_TYPE(Half, float);
CUDA_ACC_TYPE(Float8_e5m2, float);
CUDA_ACC_TYPE(Float8_e4m3fn, float);
CUDA_ACC_TYPE(Float8_e5m2fnuz, float);
CUDA_ACC_TYPE(Float8_e4m3fnuz, float);
CUDA_ACC_TYPE(float, float);
CUDA_ACC_TYPE(double, double);
CUDA_ACC_TYPE(int8_t, int64_t);
Expand All @@ -123,8 +129,8 @@ CUDA_ACC_TYPE(c10::complex<double>, c10::complex<double>);
CPU_ACC_TYPE(BFloat16, float);
CPU_ACC_TYPE(Half, float);
CPU_ACC_TYPE(Float8_e5m2, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fn, float);
CPU_ACC_TYPE(Float8_e5m2fnuz, float);
CPU_ACC_TYPE(Float8_e4m3fnuz, float);
CPU_ACC_TYPE(float, double);
CPU_ACC_TYPE(double, double);
Expand Down
26 changes: 26 additions & 0 deletions aten/src/ATen/NumericUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#include <c10/macros/Macros.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>
#include <c10/util/complex.h>

Expand Down Expand Up @@ -80,6 +82,22 @@ inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}

template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e5m2fnuz>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}

template <
typename T,
typename std::enable_if<std::is_same<T, at::Float8_e4m3fnuz>::value, int>::
type = 0>
inline C10_HOST_DEVICE bool _isnan(T val) {
return val.isnan();
}

// std::isinf isn't performant to use on integral types; it will
// (uselessly) convert to floating point and then do the test.
// This function is.
Expand Down Expand Up @@ -118,6 +136,14 @@ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) {
return false;
}

inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) {
return false;
}

inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) {
return false;
}

template <typename T>
C10_HOST_DEVICE inline T exp(T x) {
static_assert(
Expand Down
10 changes: 10 additions & 0 deletions aten/src/ATen/OpMathType.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#include <c10/util/BFloat16.h>
#include <c10/util/Exception.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Half.h>

namespace at {
Expand All @@ -31,6 +33,14 @@ struct OpMathType<at::Float8_e4m3fn> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e5m2fnuz> {
using type = float;
};
template <>
struct OpMathType<at::Float8_e4m3fnuz> {
using type = float;
};
template <>
struct OpMathType<c10::complex<Half>> {
using type = c10::complex<float>;
};
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/core/ATen_pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@
#include <c10/util/Flags.h>
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e5m2.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/Half.h>
#include <c10/util/IdWrapper.h>
Expand Down
7 changes: 7 additions & 0 deletions aten/src/ATen/cuda/CUDADataType.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type)
case c10::ScalarType::Float8_e5m2:
return CUDA_R_8F_E5M2;
#endif
#else // USE_ROCM
#if ROCM_VERSION >= 60000
case c10::ScalarType::Float8_e4m3fnuz:
return HIP_R_8F_E4M3_FNUZ;
case c10::ScalarType::Float8_e5m2fnuz:
return HIP_R_8F_E5M2_FNUZ;
#endif
#endif
default:
TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.")
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1324,9 +1324,9 @@ Tensor outer(const Tensor& self, const Tensor& vec2) {


#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, \
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \
Expand Down
11 changes: 3 additions & 8 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,7 @@ void atan2_kernel(TensorIteratorBase& iter) {
kHalf, \
kBool, \
kBFloat16, \
kFloat8_e5m2, \
kFloat8_e5m2fnuz, \
kFloat8_e4m3fn, \
kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_V2( \
TYPE, \
Expand All @@ -102,12 +99,10 @@ void atan2_kernel(TensorIteratorBase& iter) {
kComplexHalf, \
kHalf, \
kBFloat16, \
kFloat8_e5m2, \
kFloat8_e4m3fn, \
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES))
#else
#define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/BlasKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ void gemm_core_(
}

#if !defined(C10_MOBILE)
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/native/cpu/UnaryOpsKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) {
}

#if !defined(C10_MOBILE)
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \
kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \
TYPE, NAME, __VA_ARGS__)
#else
#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/CompareEQKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) {
AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() {
opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
iter, CompareEqFunctor<scalar_t>(op));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

void eq_kernel_cuda(TensorIteratorBase& iter) {
Expand Down
50 changes: 43 additions & 7 deletions aten/src/ATen/native/cuda/Copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType other_dtype = iter.dtype(1);
if (dtype == kFloat8_e4m3fn) {
switch (other_dtype) {
#if !defined(USE_ROCM)
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e4m3fn(value);
Expand All @@ -51,14 +50,12 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
return Float8_e4m3fn(value);
});
break;
#endif /* !defined(USE_ROCM) */
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; });
break;
}
} else if (dtype == kFloat8_e5m2) {
switch (other_dtype) {
#if !defined(USE_ROCM)
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
#ifdef AT_USE_NV_CVT_INTRINSICS
Expand Down Expand Up @@ -89,11 +86,52 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
#endif
});
break;
#endif /* !defined(USE_ROCM) */
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; });
break;
}
} else if (dtype == kFloat8_e4m3fnuz) {
switch (other_dtype) {
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e4m3fnuz(value);
});
break;
case kHalf:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
return Float8_e4m3fnuz(value);
});
break;
case kBFloat16:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
return Float8_e4m3fnuz(value);
});
break;
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; });
break;
}
} else if (dtype == kFloat8_e5m2fnuz) {
switch (other_dtype) {
case kFloat:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) {
return Float8_e5m2fnuz(value);
});
break;
case kHalf:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) {
return Float8_e5m2fnuz(value);
});
break;
case kBFloat16:
gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) {
return Float8_e5m2fnuz(value);
});
break;
default:
gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; });
break;
}
} else {
TORCH_CHECK(false, "This supposed ot be called only for Float8 types");
}
Expand All @@ -107,16 +145,14 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] {
gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
} else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn) {
} else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) {
float8_copy_kernel_cuda(iter);
#if !defined(USE_ROCM)
} else if (isBitsType(dtype)) {
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; });
});
#endif /* !defined(USE_ROCM) */
} else {
AT_DISPATCH_V2(
dtype, "copy_", AT_WRAP([&] {
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/FillKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct FillFunctor {
void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) {
AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() {
gpu_kernel(iter, FillFunctor<scalar_t>(value.to<scalar_t>()));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
}

REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda);
Expand Down
27 changes: 27 additions & 0 deletions aten/src/ATen/native/cuda/ROCmLoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,33 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {}
} // namespace modern


template <typename func_t>
void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
using traits = function_traits<func_t>;
using arg0_t = typename traits::result_type;
constexpr int ntensors = traits::arity + 1;

TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity);
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
TORCH_INTERNAL_ASSERT(!needs_dynamic_casting<func_t>::check(iter));

at::detail::Array<char*, ntensors> data;
for (int i = 0; i < ntensors; i++) {
data[i] = (char*)iter.data_ptr(i);
}

int64_t numel = iter.numel();

auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
legacy::launch_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
auto offsets = offset_calc.get(idx);
arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
*out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1);
});
}

template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
using traits = function_traits<func_t>;
Expand Down
11 changes: 9 additions & 2 deletions c10/core/ScalarType.h
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,15 @@ static inline bool isSignedType(ScalarType t) {
case ScalarType::ComplexFloat:
case ScalarType::ComplexDouble:
return true;
AT_FORALL_SCALAR_TYPES_AND5(
Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED)
AT_FORALL_SCALAR_TYPES_AND7(
Half,
Bool,
BFloat16,
Float8_e5m2,
Float8_e4m3fn,
Float8_e5m2fnuz,
Float8_e4m3fnuz,
CASE_SIGNED)
default:
TORCH_CHECK(false, "Unknown ScalarType");
}
Expand Down
2 changes: 1 addition & 1 deletion c10/util/Float8_e4m3fn.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) {
* mantissa will shift into exponent, turning the biased exponent into 1, and
* making mantissa normalized (i.e. without leading 1).
*/
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
uint32_t renorm_shift = __clz(nonsign);
#elif defined(__SYCL_DEVICE_ONLY__)
// Note: zero is not a supported input into `__builtin_clz`
Expand Down

0 comments on commit 01abb5a

Please sign in to comment.