diff --git a/aten/src/ATen/AccumulateType.h b/aten/src/ATen/AccumulateType.h index f96f34e1e6b6d..0275ef099b03d 100644 --- a/aten/src/ATen/AccumulateType.h +++ b/aten/src/ATen/AccumulateType.h @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include // Defines the accumulation type for a scalar type. @@ -87,6 +89,8 @@ MPS_ACC_TYPE(BFloat16, float); MPS_ACC_TYPE(Half, float); MPS_ACC_TYPE(Float8_e5m2, float); MPS_ACC_TYPE(Float8_e4m3fn, float); +MPS_ACC_TYPE(Float8_e5m2fnuz, float); +MPS_ACC_TYPE(Float8_e4m3fnuz, float); MPS_ACC_TYPE(float, float); MPS_ACC_TYPE(double, float); MPS_ACC_TYPE(int8_t, int64_t); @@ -107,6 +111,8 @@ CUDA_ACC_TYPE(BFloat16, float); CUDA_ACC_TYPE(Half, float); CUDA_ACC_TYPE(Float8_e5m2, float); CUDA_ACC_TYPE(Float8_e4m3fn, float); +CUDA_ACC_TYPE(Float8_e5m2fnuz, float); +CUDA_ACC_TYPE(Float8_e4m3fnuz, float); CUDA_ACC_TYPE(float, float); CUDA_ACC_TYPE(double, double); CUDA_ACC_TYPE(int8_t, int64_t); @@ -123,8 +129,8 @@ CUDA_ACC_TYPE(c10::complex, c10::complex); CPU_ACC_TYPE(BFloat16, float); CPU_ACC_TYPE(Half, float); CPU_ACC_TYPE(Float8_e5m2, float); -CPU_ACC_TYPE(Float8_e5m2fnuz, float); CPU_ACC_TYPE(Float8_e4m3fn, float); +CPU_ACC_TYPE(Float8_e5m2fnuz, float); CPU_ACC_TYPE(Float8_e4m3fnuz, float); CPU_ACC_TYPE(float, double); CPU_ACC_TYPE(double, double); diff --git a/aten/src/ATen/NumericUtils.h b/aten/src/ATen/NumericUtils.h index 06b25334bb13e..73da51c1a6446 100644 --- a/aten/src/ATen/NumericUtils.h +++ b/aten/src/ATen/NumericUtils.h @@ -7,7 +7,9 @@ #include #include #include +#include #include +#include #include #include @@ -80,6 +82,22 @@ inline C10_HOST_DEVICE bool _isnan(T val) { return val.isnan(); } +template < + typename T, + typename std::enable_if::value, int>:: + type = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + typename std::enable_if::value, int>:: + type = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + // std::isinf isn't performant to use on integral types; it will // (uselessly) convert to floating point and then do the test. // This function is. 
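// Illustrative sketch, not part of the patch: exercising the _isnan overloads
// added above and the _isinf overloads added in the next hunk. The fnuz
// formats encode NaN only as 0x80 (sign bit set, everything else zero) and
// have no infinities, so _isinf is unconditionally false. Assumes an ATen
// build and that the from_bits constructor mirrors the other Float8 types.
#include <ATen/NumericUtils.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <cassert>

inline void check_fnuz_special_values() {
  const c10::Float8_e4m3fnuz nan_val(0x80, c10::Float8_e4m3fnuz::from_bits());
  const c10::Float8_e4m3fnuz one(1.0f);
  assert(at::_isnan(nan_val));   // 0x80 is the single NaN encoding
  assert(!at::_isnan(one));
  assert(!at::_isinf(nan_val));  // no +/-inf exists in e4m3fnuz
  assert(!at::_isinf(one));
}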
@@ -118,6 +136,14 @@ inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val) { return false; } +inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val) { + return false; +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val) { + return false; +} + template C10_HOST_DEVICE inline T exp(T x) { static_assert( diff --git a/aten/src/ATen/OpMathType.h b/aten/src/ATen/OpMathType.h index ddb2ce71be05f..d00195b07e490 100644 --- a/aten/src/ATen/OpMathType.h +++ b/aten/src/ATen/OpMathType.h @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include namespace at { @@ -31,6 +33,14 @@ struct OpMathType { using type = float; }; template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> struct OpMathType> { using type = c10::complex; }; diff --git a/aten/src/ATen/core/ATen_pch.h b/aten/src/ATen/core/ATen_pch.h index 1f36d0ab9f87b..57ca22bf4377a 100644 --- a/aten/src/ATen/core/ATen_pch.h +++ b/aten/src/ATen/core/ATen_pch.h @@ -110,6 +110,8 @@ #include #include #include +#include +#include #include #include #include diff --git a/aten/src/ATen/cuda/CUDADataType.h b/aten/src/ATen/cuda/CUDADataType.h index 3068eb787a837..92259edd63d7d 100644 --- a/aten/src/ATen/cuda/CUDADataType.h +++ b/aten/src/ATen/cuda/CUDADataType.h @@ -92,6 +92,13 @@ inline cudaDataType ScalarTypeToCudaDataType(const c10::ScalarType& scalar_type) case c10::ScalarType::Float8_e5m2: return CUDA_R_8F_E5M2; #endif +#else // USE_ROCM +#if ROCM_VERSION >= 60000 + case c10::ScalarType::Float8_e4m3fnuz: + return HIP_R_8F_E4M3_FNUZ; + case c10::ScalarType::Float8_e5m2fnuz: + return HIP_R_8F_E5M2_FNUZ; +#endif #endif default: TORCH_INTERNAL_ASSERT(false, "Cannot convert ScalarType ", scalar_type, " to cudaDataType.") diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 9995038c14cf2..8445397e4b916 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1324,9 +1324,9 @@ Tensor outer(const Tensor& self, const Tensor& vec2) { #if !defined(C10_MOBILE) -#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ - kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, \ +#define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \ + kBFloat16, kHalf, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \ TYPE, NAME, __VA_ARGS__) #else #define _AT_DISPATCH_ADDMM_TYPES(TYPE, NAME, ...) \ diff --git a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp index 8958126d107e2..e337ea728c0e5 100644 --- a/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp @@ -90,10 +90,7 @@ void atan2_kernel(TensorIteratorBase& iter) { kHalf, \ kBool, \ kBFloat16, \ - kFloat8_e5m2, \ - kFloat8_e5m2fnuz, \ - kFloat8_e4m3fn, \ - kFloat8_e4m3fnuz, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) + AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #define _AT_DISPATCH_ALL_TYPES_NO_BOOL(TYPE, NAME, ...) 
\ AT_DISPATCH_V2( \ TYPE, \ @@ -102,12 +99,10 @@ void atan2_kernel(TensorIteratorBase& iter) { kComplexHalf, \ kHalf, \ kBFloat16, \ - kFloat8_e5m2, \ - kFloat8_e4m3fn, \ - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) + AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #define _AT_DISPATCH_MUL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_V2(TYPE, NAME, AT_WRAP(__VA_ARGS__), \ - kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) + kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)) #else #define _AT_DISPATCH_ALL_TYPES_AND_BOOL(TYPE, NAME, ...) \ AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index 0567ccada8b1d..554eb1989efde 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -268,9 +268,9 @@ void gemm_core_( } #if !defined(C10_MOBILE) -#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ - kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \ +#define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \ + kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \ TYPE, NAME, __VA_ARGS__) #else #define _AT_DISPATCH_GEMM_TYPES(TYPE, NAME, ...) \ diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 8e6322fad57ae..2522c57a2d628 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -179,9 +179,9 @@ static void logit_kernel(TensorIteratorBase& iter, const Scalar& eps_scalar) { } #if !defined(C10_MOBILE) -#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ - kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, \ +#define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \ + kHalf, kBFloat16, kFloat8_e5m2, kFloat8_e4m3fn, kFloat8_e5m2fnuz, kFloat8_e4m3fnuz, \ TYPE, NAME, __VA_ARGS__) #else #define _AT_DISPATCH_ABS_TYPES(TYPE, NAME, ...) 
\ diff --git a/aten/src/ATen/native/cuda/CompareEQKernel.cu b/aten/src/ATen/native/cuda/CompareEQKernel.cu index 9966c3b085050..9496ae95d13b2 100644 --- a/aten/src/ATen/native/cuda/CompareEQKernel.cu +++ b/aten/src/ATen/native/cuda/CompareEQKernel.cu @@ -33,7 +33,7 @@ C10_NOINLINE void compare_eq_ne_kernel(TensorIteratorBase &iter, EqOpType op) { AT_DISPATCH_V2(iter.common_dtype(), "compare_eq_ne_cuda", AT_WRAP([&]() { opmath_symmetric_gpu_kernel_with_scalars( iter, CompareEqFunctor(op)); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBFloat16, kBool, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } void eq_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 81149085354da..44402e0c4bfba 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -35,7 +35,6 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { ScalarType other_dtype = iter.dtype(1); if (dtype == kFloat8_e4m3fn) { switch (other_dtype) { -#if !defined(USE_ROCM) case kFloat: gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { return Float8_e4m3fn(value); @@ -51,14 +50,12 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { return Float8_e4m3fn(value); }); break; -#endif /* !defined(USE_ROCM) */ default: gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fn x) { return x; }); break; } } else if (dtype == kFloat8_e5m2) { switch (other_dtype) { -#if !defined(USE_ROCM) case kFloat: gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { #ifdef AT_USE_NV_CVT_INTRINSICS @@ -89,11 +86,52 @@ void float8_copy_kernel_cuda(TensorIteratorBase &iter) { #endif }); break; -#endif /* !defined(USE_ROCM) */ default: gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2 x) { return x; }); break; } + } else if (dtype == kFloat8_e4m3fnuz) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e4m3fnuz(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e4m3fnuz(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e4m3fnuz(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e4m3fnuz x) { return x; }); + break; + } + } else if (dtype == kFloat8_e5m2fnuz) { + switch (other_dtype) { + case kFloat: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(float value) { + return Float8_e5m2fnuz(value); + }); + break; + case kHalf: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(Half value) { + return Float8_e5m2fnuz(value); + }); + break; + case kBFloat16: + gpu_kernel_nocast(iter, [] GPU_LAMBDA(BFloat16 value) { + return Float8_e5m2fnuz(value); + }); + break; + default: + gpu_kernel(iter, [] GPU_LAMBDA(Float8_e5m2fnuz x) { return x; }); + break; + } } else { TORCH_CHECK(false, "This supposed ot be called only for Float8 types"); } @@ -107,16 +145,14 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) { AT_DISPATCH_QINT_TYPES(dtype, "copy_", [&] { gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); }); - } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn) { + } else if (dtype == kFloat8_e5m2 || dtype == kFloat8_e4m3fn || dtype == kFloat8_e5m2fnuz || dtype == kFloat8_e4m3fnuz) { float8_copy_kernel_cuda(iter); -#if !defined(USE_ROCM) } else if (isBitsType(dtype)) { 
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting " "bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype); AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] { gpu_kernel_nocast(iter, [] GPU_LAMBDA(scalar_t x) { return x; }); }); -#endif /* !defined(USE_ROCM) */ } else { AT_DISPATCH_V2( dtype, "copy_", AT_WRAP([&] { diff --git a/aten/src/ATen/native/cuda/FillKernel.cu b/aten/src/ATen/native/cuda/FillKernel.cu index e7e1237a6f412..dc2ecf2db35b6 100644 --- a/aten/src/ATen/native/cuda/FillKernel.cu +++ b/aten/src/ATen/native/cuda/FillKernel.cu @@ -22,7 +22,7 @@ struct FillFunctor { void fill_kernel_cuda(TensorIterator& iter, const Scalar& value) { AT_DISPATCH_V2(iter.dtype(), "fill_cuda", AT_WRAP([&]() { gpu_kernel(iter, FillFunctor(value.to())); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, kFloat8_e4m3fn, kFloat8_e5m2, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kBool, kHalf, kBFloat16, AT_EXPAND(AT_FLOAT8_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } REGISTER_DISPATCH(fill_stub, &fill_kernel_cuda); diff --git a/aten/src/ATen/native/cuda/ROCmLoops.cuh b/aten/src/ATen/native/cuda/ROCmLoops.cuh index 75811d7ae6102..6c8196c586837 100644 --- a/aten/src/ATen/native/cuda/ROCmLoops.cuh +++ b/aten/src/ATen/native/cuda/ROCmLoops.cuh @@ -298,6 +298,33 @@ static void launch_kernel(int64_t N, const func_t& f, array_t data) {} } // namespace modern +template +void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) { + using traits = function_traits; + using arg0_t = typename traits::result_type; + constexpr int ntensors = traits::arity + 1; + + TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing()); + TORCH_INTERNAL_ASSERT(iter.ninputs() == traits::arity); + TORCH_INTERNAL_ASSERT(iter.noutputs() == 1); + TORCH_INTERNAL_ASSERT(!needs_dynamic_casting::check(iter)); + + at::detail::Array data; + for (int i = 0; i < ntensors; i++) { + data[i] = (char*)iter.data_ptr(i); + } + + int64_t numel = iter.numel(); + + auto offset_calc = ::make_offset_calculator(iter); + constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4; + legacy::launch_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) { + auto offsets = offset_calc.get(idx); + arg0_t* out = (arg0_t*)(data[0] + offsets[0]); + *out = legacy::invoke(f, &data.data[1], &offsets.data[1], 1); + }); +} + template void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) { using traits = function_traits; diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 3a91dae8741d5..b97e43f9e683c 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -525,8 +525,15 @@ static inline bool isSignedType(ScalarType t) { case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: return true; - AT_FORALL_SCALAR_TYPES_AND5( - Half, Bool, BFloat16, Float8_e5m2, Float8_e4m3fn, CASE_SIGNED) + AT_FORALL_SCALAR_TYPES_AND7( + Half, + Bool, + BFloat16, + Float8_e5m2, + Float8_e4m3fn, + Float8_e5m2fnuz, + Float8_e4m3fnuz, + CASE_SIGNED) default: TORCH_CHECK(false, "Unknown ScalarType"); } diff --git a/c10/util/Float8_e4m3fn.h b/c10/util/Float8_e4m3fn.h index 8d3e339ca6196..bb6c0be9d076d 100644 --- a/c10/util/Float8_e4m3fn.h +++ b/c10/util/Float8_e4m3fn.h @@ -96,7 +96,7 @@ inline C10_HOST_DEVICE float fp8e4m3fn_to_fp32_value(uint8_t input) { * mantissa will shift into exponent, turning the biased exponent into 1, and * making mantissa normalized (i.e. without leading 1). 
*/ -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) uint32_t renorm_shift = __clz(nonsign); #elif defined(__SYCL_DEVICE_ONLY__) // Note: zero is not a supported input into `__builtin_clz` diff --git a/c10/util/Float8_e4m3fnuz-inl.h b/c10/util/Float8_e4m3fnuz-inl.h index c1aab8bfe4dcc..e89eaeadd47b4 100644 --- a/c10/util/Float8_e4m3fnuz-inl.h +++ b/c10/util/Float8_e4m3fnuz-inl.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include C10_CLANG_DIAGNOSTIC_PUSH() @@ -12,21 +14,208 @@ namespace c10 { /// Constructors -C10_HOST_DEVICE inline Float8_e4m3fnuz::Float8_e4m3fnuz(float value) +inline C10_HOST_DEVICE Float8_e4m3fnuz::Float8_e4m3fnuz(float value) : x(detail::fp8e4m3fnuz_from_fp32_value(value)) {} /// Implicit conversions -C10_HOST_DEVICE inline Float8_e4m3fnuz::operator float() const { - return detail::fp8e4m3fnuz_to_fp32_value(x); +inline C10_HOST_DEVICE Float8_e4m3fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<4, 3>(x); } /// Special values helper -C10_HOST_DEVICE inline bool Float8_e4m3fnuz::isnan() const { +inline C10_HOST_DEVICE bool Float8_e4m3fnuz::isnan() const { return x == 0b10000000; } +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator+(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator-(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz +operator*(const Float8_e4m3fnuz& a, const Float8_e4m3fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/( + const Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(const Float8_e4m3fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator+=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator-=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator*=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz& operator/=( + Float8_e4m3fnuz& a, + const Float8_e4m3fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e4m3fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e4m3fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e4m3fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e4m3fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE float operator*(float a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e4m3fnuz& b) { + return a += static_cast(b); +} 
+inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e4m3fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e4m3fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e4m3fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e4m3fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e4m3fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e4m3fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e4m3fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e4m3fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e4m3fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e4m3fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e4m3fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(Float8_e4m3fnuz a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(Float8_e4m3fnuz a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(Float8_e4m3fnuz a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(Float8_e4m3fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e4m3fnuz operator+(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator-(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator*(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e4m3fnuz operator/(int64_t a, Float8_e4m3fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons directly and instead rely on the implicit +/// conversion from c10::Float8_e4m3fnuz to float. 
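// Illustrative sketch, not part of the patch: the operators defined above
// widen both operands to float, do the arithmetic in fp32, and only narrow
// back to Float8_e4m3fnuz when the result is assigned to one, so mixed
// float expressions keep full fp32 precision. Assumes a c10 build.
#include <c10/util/Float8_e4m3fnuz.h>

inline void fnuz_arithmetic_example() {
  c10::Float8_e4m3fnuz a(1.5f), b(0.25f);
  c10::Float8_e4m3fnuz sum = a + b;  // fp32 add (1.75f), then rounded back to e4m3fnuz
  float mixed = a + 0.25f;           // Float8 + float yields a plain float
  (void)sum;
  (void)mixed;
}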
+ } // namespace c10 namespace std { diff --git a/c10/util/Float8_e4m3fnuz.cpp b/c10/util/Float8_e4m3fnuz.cpp index 7bf301ff1b4a7..b25ccda22e6fc 100644 --- a/c10/util/Float8_e4m3fnuz.cpp +++ b/c10/util/Float8_e4m3fnuz.cpp @@ -1,276 +1,8 @@ #include -#include #include namespace c10 { -namespace detail { - -C10_HOST_DEVICE float fp8e4m3fnuz_to_fp32_value(uint8_t input) { - constexpr std::array e4m3fnuz_lut = { - 0.0f, - 0.0009765625f, - 0.001953125f, - 0.0029296875f, - 0.00390625f, - 0.0048828125f, - 0.005859375f, - 0.0068359375f, - 0.0078125f, - 0.0087890625f, - 0.009765625f, - 0.0107421875f, - 0.01171875f, - 0.0126953125f, - 0.013671875f, - 0.0146484375f, - 0.015625f, - 0.017578125f, - 0.01953125f, - 0.021484375f, - 0.0234375f, - 0.025390625f, - 0.02734375f, - 0.029296875f, - 0.03125f, - 0.03515625f, - 0.0390625f, - 0.04296875f, - 0.046875f, - 0.05078125f, - 0.0546875f, - 0.05859375f, - 0.0625f, - 0.0703125f, - 0.078125f, - 0.0859375f, - 0.09375f, - 0.1015625f, - 0.109375f, - 0.1171875f, - 0.125f, - 0.140625f, - 0.15625f, - 0.171875f, - 0.1875f, - 0.203125f, - 0.21875f, - 0.234375f, - 0.25f, - 0.28125f, - 0.3125f, - 0.34375f, - 0.375f, - 0.40625f, - 0.4375f, - 0.46875f, - 0.5f, - 0.5625f, - 0.625f, - 0.6875f, - 0.75f, - 0.8125f, - 0.875f, - 0.9375f, - 1.0f, - 1.125f, - 1.25f, - 1.375f, - 1.5f, - 1.625f, - 1.75f, - 1.875f, - 2.0f, - 2.25f, - 2.5f, - 2.75f, - 3.0f, - 3.25f, - 3.5f, - 3.75f, - 4.0f, - 4.5f, - 5.0f, - 5.5f, - 6.0f, - 6.5f, - 7.0f, - 7.5f, - 8.0f, - 9.0f, - 10.0f, - 11.0f, - 12.0f, - 13.0f, - 14.0f, - 15.0f, - 16.0f, - 18.0f, - 20.0f, - 22.0f, - 24.0f, - 26.0f, - 28.0f, - 30.0f, - 32.0f, - 36.0f, - 40.0f, - 44.0f, - 48.0f, - 52.0f, - 56.0f, - 60.0f, - 64.0f, - 72.0f, - 80.0f, - 88.0f, - 96.0f, - 104.0f, - 112.0f, - 120.0f, - 128.0f, - 144.0f, - 160.0f, - 176.0f, - 192.0f, - 208.0f, - 224.0f, - 240.0f, - std::numeric_limits::signaling_NaN(), - -0.0009765625f, - -0.001953125f, - -0.0029296875f, - -0.00390625f, - -0.0048828125f, - -0.005859375f, - -0.0068359375f, - -0.0078125f, - -0.0087890625f, - -0.009765625f, - -0.0107421875f, - -0.01171875f, - -0.0126953125f, - -0.013671875f, - -0.0146484375f, - -0.015625f, - -0.017578125f, - -0.01953125f, - -0.021484375f, - -0.0234375f, - -0.025390625f, - -0.02734375f, - -0.029296875f, - -0.03125f, - -0.03515625f, - -0.0390625f, - -0.04296875f, - -0.046875f, - -0.05078125f, - -0.0546875f, - -0.05859375f, - -0.0625f, - -0.0703125f, - -0.078125f, - -0.0859375f, - -0.09375f, - -0.1015625f, - -0.109375f, - -0.1171875f, - -0.125f, - -0.140625f, - -0.15625f, - -0.171875f, - -0.1875f, - -0.203125f, - -0.21875f, - -0.234375f, - -0.25f, - -0.28125f, - -0.3125f, - -0.34375f, - -0.375f, - -0.40625f, - -0.4375f, - -0.46875f, - -0.5f, - -0.5625f, - -0.625f, - -0.6875f, - -0.75f, - -0.8125f, - -0.875f, - -0.9375f, - -1.0f, - -1.125f, - -1.25f, - -1.375f, - -1.5f, - -1.625f, - -1.75f, - -1.875f, - -2.0f, - -2.25f, - -2.5f, - -2.75f, - -3.0f, - -3.25f, - -3.5f, - -3.75f, - -4.0f, - -4.5f, - -5.0f, - -5.5f, - -6.0f, - -6.5f, - -7.0f, - -7.5f, - -8.0f, - -9.0f, - -10.0f, - -11.0f, - -12.0f, - -13.0f, - -14.0f, - -15.0f, - -16.0f, - -18.0f, - -20.0f, - -22.0f, - -24.0f, - -26.0f, - -28.0f, - -30.0f, - -32.0f, - -36.0f, - -40.0f, - -44.0f, - -48.0f, - -52.0f, - -56.0f, - -60.0f, - -64.0f, - -72.0f, - -80.0f, - -88.0f, - -96.0f, - -104.0f, - -112.0f, - -120.0f, - -128.0f, - -144.0f, - -160.0f, - -176.0f, - -192.0f, - -208.0f, - -224.0f, - -240.0f, - }; - - return e4m3fnuz_lut[input]; -} - -} // namespace detail - static_assert( std::is_standard_layout_v, 
"c10::Float8_e4m3fnuz must be standard layout."); diff --git a/c10/util/Float8_e4m3fnuz.h b/c10/util/Float8_e4m3fnuz.h index 0b42c062a280a..6e066a5240f85 100644 --- a/c10/util/Float8_e4m3fnuz.h +++ b/c10/util/Float8_e4m3fnuz.h @@ -4,13 +4,11 @@ /// conversions to standard C types and basic arithmetic operations. Note that /// arithmetic operations are implemented by converting to floating point and /// performing the operation in float32. -/// /// Binary configuration remains the same as Float8_e4m3fn: /// s eeee mmm /// 1 sign bit /// 4 exponent bits /// 3 mantissa bits -/// /// The key differences versus Float8_e4m3fn are: /// bias = 8 /// no infinities or negative zero @@ -23,6 +21,7 @@ #include #include #include +#include #if defined(__cplusplus) && (__cplusplus >= 201103L) #include @@ -38,27 +37,11 @@ namespace c10 { namespace detail { -/* - * Convert a 8-bit floating-point number in fp8 E4M3FNUZ format, in bit - * representation, to a 32-bit floating-point number in IEEE single-precision - * format, in bit representation. - * - * @note The implementation doesn't use any floating-point operations. - */ -#if defined(__CUDA_ARCH__) || defined(__HIP__) -C10_HOST_DEVICE C10_API inline float fp8e4m3fnuz_to_fp32_value(uint8_t) { - CUDA_KERNEL_ASSERT(false && "e4m3fnuz is not supported by CUDA or HIP"); - return -1.0; -} -#else -C10_API float fp8e4m3fnuz_to_fp32_value(uint8_t input); -#endif - /* * Convert a 32-bit floating-point number in IEEE single-precision format to a * 8-bit floating-point number in fp8 E4M3FNUZ format, in bit representation. */ -C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) { +inline C10_HOST_DEVICE uint8_t fp8e4m3fnuz_from_fp32_value(float f) { /* * Binary representation of 256.0f, which is the first value not representable * (i.e. the first value which would overflow in to the sign bit, resulting in @@ -70,7 +53,7 @@ C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) { /* * A mask for converting fp32 numbers lower than fp8e4m3fnuz normal range - * into denormalized representation. 
+ * into denorm representation * magic number: ((127 - 8) + (23 - 3) + 1) */ constexpr uint32_t denorm_mask = UINT32_C(0x8C) << 23; @@ -123,7 +106,6 @@ C10_HOST_DEVICE inline uint8_t fp8e4m3fnuz_from_fp32_value(float f) { } result |= sign >> 24; - return result; } @@ -133,7 +115,7 @@ struct alignas(1) Float8_e4m3fnuz { uint8_t x; struct from_bits_t {}; - static constexpr C10_HOST_DEVICE from_bits_t from_bits() { + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } diff --git a/c10/util/Float8_e5m2fnuz-inl.h b/c10/util/Float8_e5m2fnuz-inl.h index 8aad01f445842..3af233a87b844 100644 --- a/c10/util/Float8_e5m2fnuz-inl.h +++ b/c10/util/Float8_e5m2fnuz-inl.h @@ -1,6 +1,8 @@ #pragma once #include +#include +#include #include C10_CLANG_DIAGNOSTIC_PUSH() @@ -12,21 +14,212 @@ namespace c10 { /// Constructors -C10_HOST_DEVICE inline Float8_e5m2fnuz::Float8_e5m2fnuz(float value) +inline C10_HOST_DEVICE Float8_e5m2fnuz::Float8_e5m2fnuz(float value) : x(detail::fp8e5m2fnuz_from_fp32_value(value)) {} /// Implicit conversions -C10_HOST_DEVICE inline Float8_e5m2fnuz::operator float() const { - return detail::fp8e5m2fnuz_to_fp32_value(x); +inline C10_HOST_DEVICE Float8_e5m2fnuz::operator float() const { + return detail::fp8_fnuz_to_fp32_value<5, 2>(x); } /// Special values helpers -C10_HOST_DEVICE inline bool Float8_e5m2fnuz::isnan() const { +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isnan() const { return x == 0b10000000; } +inline C10_HOST_DEVICE bool Float8_e5m2fnuz::isinf() const { + return false; +} + +/// Arithmetic + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator+(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) + static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator-(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) - static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz +operator*(const Float8_e5m2fnuz& a, const Float8_e5m2fnuz& b) { + return static_cast(a) * static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/( + const Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(const Float8_e5m2fnuz& a) { + return -static_cast(a); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator+=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a + b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator-=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a - b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator*=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a * b; + return a; +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz& operator/=( + Float8_e5m2fnuz& a, + const Float8_e5m2fnuz& b) { + a = a / b; + return a; +} + +/// Arithmetic with floats + +inline C10_HOST_DEVICE float operator+(Float8_e5m2fnuz a, float b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE float operator-(Float8_e5m2fnuz a, float b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE float operator*(Float8_e5m2fnuz a, float b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE float operator/(Float8_e5m2fnuz a, float b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE float operator+(float a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE float operator-(float a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} 
+inline C10_HOST_DEVICE float operator*(float a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE float operator/(float a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE float& operator+=(float& a, const Float8_e5m2fnuz& b) { + return a += static_cast(b); +} +inline C10_HOST_DEVICE float& operator-=(float& a, const Float8_e5m2fnuz& b) { + return a -= static_cast(b); +} +inline C10_HOST_DEVICE float& operator*=(float& a, const Float8_e5m2fnuz& b) { + return a *= static_cast(b); +} +inline C10_HOST_DEVICE float& operator/=(float& a, const Float8_e5m2fnuz& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline C10_HOST_DEVICE double operator+(Float8_e5m2fnuz a, double b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE double operator-(Float8_e5m2fnuz a, double b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE double operator*(Float8_e5m2fnuz a, double b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE double operator/(Float8_e5m2fnuz a, double b) + __ubsan_ignore_float_divide_by_zero__ { + return static_cast(a) / b; +} + +inline C10_HOST_DEVICE double operator+(double a, Float8_e5m2fnuz b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE double operator-(double a, Float8_e5m2fnuz b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE double operator*(double a, Float8_e5m2fnuz b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE double operator/(double a, Float8_e5m2fnuz b) + __ubsan_ignore_float_divide_by_zero__ { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(Float8_e5m2fnuz a, int64_t b) { + return a + static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(Float8_e5m2fnuz a, int64_t b) { + return a - static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(Float8_e5m2fnuz a, int64_t b) { + return a * static_cast(b); +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(Float8_e5m2fnuz a, int64_t b) { + return a / static_cast(b); +} + +inline C10_HOST_DEVICE Float8_e5m2fnuz operator+(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) + b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator-(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) - b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator*(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) * b; +} +inline C10_HOST_DEVICE Float8_e5m2fnuz operator/(int64_t a, Float8_e5m2fnuz b) { + return static_cast(a) / b; +} + +/// NOTE: we do not define comparisons 
directly and instead rely on the implicit +/// conversion from c10::Float8_e5m2fnuz to float. + } // namespace c10 namespace std { diff --git a/c10/util/Float8_e5m2fnuz.cpp b/c10/util/Float8_e5m2fnuz.cpp index c98f6dc6d7613..de23b19af3e1a 100644 --- a/c10/util/Float8_e5m2fnuz.cpp +++ b/c10/util/Float8_e5m2fnuz.cpp @@ -1,276 +1,8 @@ #include -#include #include namespace c10 { -namespace detail { - -C10_HOST_DEVICE float fp8e5m2fnuz_to_fp32_value(uint8_t input) { - constexpr std::array e5m2fnuz_lut = { - 0.0f, - 7.62939453125e-06f, - 1.52587890625e-05f, - 2.288818359375e-05f, - 3.0517578125e-05f, - 3.814697265625e-05f, - 4.57763671875e-05f, - 5.340576171875e-05f, - 6.103515625e-05f, - 7.62939453125e-05f, - 9.1552734375e-05f, - 0.0001068115234375f, - 0.0001220703125f, - 0.000152587890625f, - 0.00018310546875f, - 0.000213623046875f, - 0.000244140625f, - 0.00030517578125f, - 0.0003662109375f, - 0.00042724609375f, - 0.00048828125f, - 0.0006103515625f, - 0.000732421875f, - 0.0008544921875f, - 0.0009765625f, - 0.001220703125f, - 0.00146484375f, - 0.001708984375f, - 0.001953125f, - 0.00244140625f, - 0.0029296875f, - 0.00341796875f, - 0.00390625f, - 0.0048828125f, - 0.005859375f, - 0.0068359375f, - 0.0078125f, - 0.009765625f, - 0.01171875f, - 0.013671875f, - 0.015625f, - 0.01953125f, - 0.0234375f, - 0.02734375f, - 0.03125f, - 0.0390625f, - 0.046875f, - 0.0546875f, - 0.0625f, - 0.078125f, - 0.09375f, - 0.109375f, - 0.125f, - 0.15625f, - 0.1875f, - 0.21875f, - 0.25f, - 0.3125f, - 0.375f, - 0.4375f, - 0.5f, - 0.625f, - 0.75f, - 0.875f, - 1.0f, - 1.25f, - 1.5f, - 1.75f, - 2.0f, - 2.5f, - 3.0f, - 3.5f, - 4.0f, - 5.0f, - 6.0f, - 7.0f, - 8.0f, - 10.0f, - 12.0f, - 14.0f, - 16.0f, - 20.0f, - 24.0f, - 28.0f, - 32.0f, - 40.0f, - 48.0f, - 56.0f, - 64.0f, - 80.0f, - 96.0f, - 112.0f, - 128.0f, - 160.0f, - 192.0f, - 224.0f, - 256.0f, - 320.0f, - 384.0f, - 448.0f, - 512.0f, - 640.0f, - 768.0f, - 896.0f, - 1024.0f, - 1280.0f, - 1536.0f, - 1792.0f, - 2048.0f, - 2560.0f, - 3072.0f, - 3584.0f, - 4096.0f, - 5120.0f, - 6144.0f, - 7168.0f, - 8192.0f, - 10240.0f, - 12288.0f, - 14336.0f, - 16384.0f, - 20480.0f, - 24576.0f, - 28672.0f, - 32768.0f, - 40960.0f, - 49152.0f, - 57344.0f, - std::numeric_limits::signaling_NaN(), - -7.62939453125e-06f, - -1.52587890625e-05f, - -2.288818359375e-05f, - -3.0517578125e-05f, - -3.814697265625e-05f, - -4.57763671875e-05f, - -5.340576171875e-05f, - -6.103515625e-05f, - -7.62939453125e-05f, - -9.1552734375e-05f, - -0.0001068115234375f, - -0.0001220703125f, - -0.000152587890625f, - -0.00018310546875f, - -0.000213623046875f, - -0.000244140625f, - -0.00030517578125f, - -0.0003662109375f, - -0.00042724609375f, - -0.00048828125f, - -0.0006103515625f, - -0.000732421875f, - -0.0008544921875f, - -0.0009765625f, - -0.001220703125f, - -0.00146484375f, - -0.001708984375f, - -0.001953125f, - -0.00244140625f, - -0.0029296875f, - -0.00341796875f, - -0.00390625f, - -0.0048828125f, - -0.005859375f, - -0.0068359375f, - -0.0078125f, - -0.009765625f, - -0.01171875f, - -0.013671875f, - -0.015625f, - -0.01953125f, - -0.0234375f, - -0.02734375f, - -0.03125f, - -0.0390625f, - -0.046875f, - -0.0546875f, - -0.0625f, - -0.078125f, - -0.09375f, - -0.109375f, - -0.125f, - -0.15625f, - -0.1875f, - -0.21875f, - -0.25f, - -0.3125f, - -0.375f, - -0.4375f, - -0.5f, - -0.625f, - -0.75f, - -0.875f, - -1.0f, - -1.25f, - -1.5f, - -1.75f, - -2.0f, - -2.5f, - -3.0f, - -3.5f, - -4.0f, - -5.0f, - -6.0f, - -7.0f, - -8.0f, - -10.0f, - -12.0f, - -14.0f, - -16.0f, - -20.0f, - -24.0f, - -28.0f, - -32.0f, - -40.0f, - -48.0f, - -56.0f, - 
-64.0f, - -80.0f, - -96.0f, - -112.0f, - -128.0f, - -160.0f, - -192.0f, - -224.0f, - -256.0f, - -320.0f, - -384.0f, - -448.0f, - -512.0f, - -640.0f, - -768.0f, - -896.0f, - -1024.0f, - -1280.0f, - -1536.0f, - -1792.0f, - -2048.0f, - -2560.0f, - -3072.0f, - -3584.0f, - -4096.0f, - -5120.0f, - -6144.0f, - -7168.0f, - -8192.0f, - -10240.0f, - -12288.0f, - -14336.0f, - -16384.0f, - -20480.0f, - -24576.0f, - -28672.0f, - -32768.0f, - -40960.0f, - -49152.0f, - -57344.0f, - }; - - return e5m2fnuz_lut[input]; -} - -} // namespace detail - static_assert( std::is_standard_layout_v, "c10::Float8_e5m2 must be standard layout."); diff --git a/c10/util/Float8_e5m2fnuz.h b/c10/util/Float8_e5m2fnuz.h index e09ce99fbb548..5418705737c66 100644 --- a/c10/util/Float8_e5m2fnuz.h +++ b/c10/util/Float8_e5m2fnuz.h @@ -4,13 +4,11 @@ /// conversions to standard C types and basic arithmetic operations. Note that /// arithmetic operations are implemented by converting to floating point and /// performing the operation in float32. -/// /// Binary configuration remains the same as e5m2: /// s eeeee mm /// 1 sign bit /// 5 exponent bits /// 2 mantissa bits -/// /// The key differences that e5m2fnuz brings are: /// bias = 16 /// no infinities or negative zero @@ -38,27 +36,11 @@ namespace c10 { namespace detail { -/* - * Convert a 8-bit floating-point number in fp8 E5M2FNUZ format, in bit - * representation, to a 32-bit floating-point number in IEEE single-precision - * format, in bit representation. - * - * @note The implementation doesn't use any floating-point operations. - */ -#if defined(__CUDA_ARCH__) || defined(__HIP__) -C10_HOST_DEVICE C10_API inline float fp8e5m2fnuz_to_fp32_value(uint8_t) { - CUDA_KERNEL_ASSERT(false && "e5m2fnuz is not supported by CUDA or HIP"); - return -1.0; -} -#else -C10_API float fp8e5m2fnuz_to_fp32_value(uint8_t input); -#endif - /* * Convert a 32-bit floating-point number in IEEE single-precision format to a * 8-bit floating-point number in fp8 E5M2 format, in bit representation. */ -C10_HOST_DEVICE inline uint8_t fp8e5m2fnuz_from_fp32_value(float f) { +inline C10_HOST_DEVICE uint8_t fp8e5m2fnuz_from_fp32_value(float f) { /* * Binary representation of 65536.0f, which is the first value not * representable (i.e. the first value which would overflow in to the sign @@ -76,7 +58,6 @@ C10_HOST_DEVICE inline uint8_t fp8e5m2fnuz_from_fp32_value(float f) { constexpr uint32_t denorm_mask = UINT32_C(0x85) << 23; uint32_t f_bits = fp32_to_bits(f); - uint32_t result = 0u; /* @@ -132,7 +113,7 @@ struct alignas(1) Float8_e5m2fnuz { uint8_t x; struct from_bits_t {}; - static constexpr C10_HOST_DEVICE from_bits_t from_bits() { + C10_HOST_DEVICE static constexpr from_bits_t from_bits() { return from_bits_t(); } @@ -143,6 +124,7 @@ struct alignas(1) Float8_e5m2fnuz { inline C10_HOST_DEVICE Float8_e5m2fnuz(float value); inline C10_HOST_DEVICE operator float() const; inline C10_HOST_DEVICE bool isnan() const; + inline C10_HOST_DEVICE bool isinf() const; }; C10_API std::ostream& operator<<( diff --git a/c10/util/Float8_fnuz_cvt.h b/c10/util/Float8_fnuz_cvt.h new file mode 100644 index 0000000000000..983063a0230fc --- /dev/null +++ b/c10/util/Float8_fnuz_cvt.h @@ -0,0 +1,58 @@ +#pragma once + +#include + +#include + +namespace c10::detail { + +/* + * Convert a 8-bit floating-point number in either f8 E4M3FNUZ or bf8 E5M2FNUZ + * format, in bit representation, to a 32-bit floating-point number. 
+ */ +template <uint32_t we, uint32_t wm> +inline C10_HOST_DEVICE float fp8_fnuz_to_fp32_value(uint8_t x) { + static_assert((we == 4 && wm == 3) || (we == 5 && wm == 2)); + constexpr uint32_t weo = 8; + constexpr uint32_t wmo = 23; + + if (x == 0) { + return 0; + } + + if (x == 0x80) { + constexpr uint32_t ifNaN = 0x7F800001; + return fp32_from_bits(ifNaN); + } + + uint32_t mantissa = x & ((1 << wm) - 1); + uint32_t exponent = (x & 0x7F) >> wm; + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) + uint32_t renorm_shift = __clz(mantissa); +#elif defined(_MSC_VER) + unsigned long nonsign_bsr; + _BitScanReverse(&nonsign_bsr, (unsigned long)mantissa); + uint32_t renorm_shift = (uint32_t)nonsign_bsr ^ 31; +#else + uint32_t renorm_shift = __builtin_clz(mantissa); +#endif + uint32_t sh = 1 + renorm_shift - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + + const uint32_t exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)); + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + uint32_t sign = x >> 7; + uint32_t retval = (sign << 31) | (exponent << 23) | mantissa; + return fp32_from_bits(retval); +} + +} // namespace c10::detail diff --git a/torch/_C/_onnx.pyi b/torch/_C/_onnx.pyi index 376f461c35881..2e8e5a0c66117 100644 --- a/torch/_C/_onnx.pyi +++ b/torch/_C/_onnx.pyi @@ -25,6 +25,8 @@ class TensorProtoDataType(Enum): BFLOAT16 = ... FLOAT8E5M2 = ... FLOAT8E4M3FN = ... + FLOAT8E5M2FNUZ = ... + FLOAT8E4M3FNUZ = ... class OperatorExportTypes(Enum): ONNX = ... diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index ab6ce29e952a7..f7b0c5d91b385 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -85,6 +85,8 @@ torch.complex64: "at::kComplexFloat", torch.float8_e4m3fn: "at::kFloat8_e4m3fn", torch.float8_e5m2: "at::kFloat8_e5m2", + torch.float8_e4m3fnuz: "at::kFloat8_e4m3fnuz", + torch.float8_e5m2fnuz: "at::kFloat8_e5m2fnuz", } DEVICE_TO_ATEN = { diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 63bb79848f95d..922e0e1bdcc49 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -387,6 +387,10 @@ def triton_compute_type(dtype): triton_type_name = "float8e4nv" elif triton_type_name == "float8_e5m2": triton_type_name = "float8e5" + elif triton_type_name == "float8_e4m3fnuz": + triton_type_name = "float8e4b8" + elif triton_type_name == "float8_e5m2fnuz": + triton_type_name = "float8e5b16" return f"tl.{triton_type_name}" diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 7e4596e47690d..43f044abdc892 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -18,6 +18,10 @@ def signature_of(arg: Union[TensorArg, SizeArg], *, size_dtype: str) -> str: tye = "*fp8e4nv" elif arg.dtype == torch.float8_e5m2: tye = "*fp8e5" + elif arg.dtype == torch.float8_e4m3fnuz: + tye = "*fp8e4b8" + elif arg.dtype == torch.float8_e5m2fnuz: + tye = "*fp8e5b16" else: tye = JITFunction._type_of(arg.dtype) if V.graph.is_unspec_arg(arg.buffer): diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 08b7ca0b95066..e02b7c4fc11b6 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -91,6 +91,8 @@ def supported_dtype_of_cpp_wrapper(dtype, cuda): if cuda: supported_dtype.add(torch.float8_e4m3fn) supported_dtype.add(torch.float8_e5m2) +
supported_dtype.add(torch.float8_e4m3fnuz) + supported_dtype.add(torch.float8_e5m2fnuz) return dtype in supported_dtype diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index bdac7a21956c4..0917e9e75dbf0 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -5520,7 +5520,12 @@ def is_col_major(shape, stride): return stride[0] == 1 and stride[1] == shape[0] def is_fp8_type(dtype): - return dtype in (torch.float8_e4m3fn, torch.float8_e5m2) + return dtype in ( + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + ) torch._check( self.dim() == 2 and mat2.dim() == 2, diff --git a/torch/_tensor.py b/torch/_tensor.py index 7c799d1f62bbc..607c46453d6c4 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -397,6 +397,8 @@ def _reduce_ex_internal(self, proto): v3_dtypes = [ torch.float8_e5m2, torch.float8_e4m3fn, + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, torch.bits8, torch.bits16, torch.bits1x8, diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 2acf049a384aa..c8f24bef1b51d 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -81,6 +81,8 @@ def _get_allowed_globals(): torch.complex128, torch.float8_e5m2, torch.float8_e4m3fn, + torch.float8_e5m2fnuz, + torch.float8_e4m3fnuz, torch.float16, torch.float32, torch.float64, diff --git a/torch/csrc/inductor/aoti_torch/c/shim.h b/torch/csrc/inductor/aoti_torch/c/shim.h index ac36a1941bf52..1fb716e30cade 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim.h +++ b/torch/csrc/inductor/aoti_torch/c/shim.h @@ -85,6 +85,8 @@ AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cuda(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16(); AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32(); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index aa1ca63ee9fef..a068206bb1ff2 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -74,6 +74,14 @@ int32_t aoti_torch_dtype_float8_e4m3fn() { return (int32_t)c10::ScalarType::Float8_e4m3fn; } +int32_t aoti_torch_dtype_float8_e5m2fnuz() { + return (int32_t)c10::ScalarType::Float8_e5m2fnuz; +} + +int32_t aoti_torch_dtype_float8_e4m3fnuz() { + return (int32_t)c10::ScalarType::Float8_e4m3fnuz; +} + int32_t aoti_torch_dtype_bfloat16() { return (int32_t)c10::ScalarType::BFloat16; } diff --git a/torch/csrc/jit/passes/onnx/helper.cpp b/torch/csrc/jit/passes/onnx/helper.cpp index 2f4757546bc14..d6b2a6385fab4 100644 --- a/torch/csrc/jit/passes/onnx/helper.cpp +++ b/torch/csrc/jit/passes/onnx/helper.cpp @@ -91,8 +91,12 @@ c10::optional ONNXTypeToATenType(int32_t onnx_type) { return at::kBFloat16; case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2: return at::kFloat8_e5m2; + case ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ: + return at::kFloat8_e5m2fnuz; case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN: return at::kFloat8_e4m3fn; + case ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ: + return at::kFloat8_e4m3fnuz; default: TORCH_CHECK( false, diff --git a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp 
b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp index 1467a63a134c0..ef6342daff758 100644 --- a/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp +++ b/torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp @@ -31,6 +31,8 @@ static const std::unordered_map {c10::kBFloat16, 15}, {c10::kFloat8_e4m3fn, 16}, {c10::kFloat8_e5m2, 17}, + {c10::kFloat8_e4m3fnuz, 18}, + {c10::kFloat8_e5m2fnuz, 19}, }; static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) { diff --git a/torch/csrc/jit/serialization/export.cpp b/torch/csrc/jit/serialization/export.cpp index 16a0a7bcdce80..7f9078787f09e 100644 --- a/torch/csrc/jit/serialization/export.cpp +++ b/torch/csrc/jit/serialization/export.cpp @@ -469,6 +469,10 @@ onnx::TensorProto_DataType ATenTypeToOnnxType(at::ScalarType at_type) { return onnx_torch::TensorProto_DataType_FLOAT8E4M3FN; case at::kFloat8_e5m2: return onnx_torch::TensorProto_DataType_FLOAT8E5M2; + case at::kFloat8_e4m3fnuz: + return onnx_torch::TensorProto_DataType_FLOAT8E4M3FNUZ; + case at::kFloat8_e5m2fnuz: + return onnx_torch::TensorProto_DataType_FLOAT8E5M2FNUZ; default: TORCH_CHECK( false, diff --git a/torch/csrc/jit/tensorexpr/types.cpp b/torch/csrc/jit/tensorexpr/types.cpp index 75dc8ec23f274..3791fead0da19 100644 --- a/torch/csrc/jit/tensorexpr/types.cpp +++ b/torch/csrc/jit/tensorexpr/types.cpp @@ -74,8 +74,15 @@ int Dtype::byte_size() const { scalar_size = sizeof(Type); \ break; - AT_FORALL_SCALAR_TYPES_AND5( - Bool, Half, BFloat16, Float8_e5m2, Float8_e4m3fn, TYPE_CASE); + AT_FORALL_SCALAR_TYPES_AND7( + Bool, + Half, + BFloat16, + Float8_e5m2, + Float8_e4m3fn, + Float8_e5m2fnuz, + Float8_e4m3fnuz, + TYPE_CASE); TYPE_CASE(c10::quint8, QUInt8); TYPE_CASE(c10::qint8, QInt8); #undef TYPE_CASE diff --git a/torch/csrc/onnx/back_compat.h b/torch/csrc/onnx/back_compat.h index d5e58c8f9d874..9afefe345388f 100644 --- a/torch/csrc/onnx/back_compat.h +++ b/torch/csrc/onnx/back_compat.h @@ -12,8 +12,14 @@ namespace torch::onnx { // ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN constexpr auto TensorProto_DataType_FLOAT8E4M3FN = static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(17); +// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ +constexpr auto TensorProto_DataType_FLOAT8E4M3FNUZ = + static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(18); // ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2 constexpr auto TensorProto_DataType_FLOAT8E5M2 = static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(19); +// ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ +constexpr auto TensorProto_DataType_FLOAT8E5M2FNUZ = + static_cast<::ONNX_NAMESPACE::TensorProto_DataType>(20); } // namespace torch::onnx diff --git a/torch/csrc/onnx/init.cpp b/torch/csrc/onnx/init.cpp index 37e5341fe3f63..825ed46e11a50 100644 --- a/torch/csrc/onnx/init.cpp +++ b/torch/csrc/onnx/init.cpp @@ -274,7 +274,11 @@ void initONNXBindings(PyObject* module) { .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128) .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) .value("FLOAT8E4M3FN", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FN) - .value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2); + .value( + "FLOAT8E4M3FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E4M3FNUZ) + .value("FLOAT8E5M2", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2) + .value( + "FLOAT8E5M2FNUZ", ::torch::onnx::TensorProto_DataType_FLOAT8E5M2FNUZ); py::enum_(onnx, "OperatorExportTypes") .value("ONNX", OperatorExportTypes::ONNX) diff --git a/torch/csrc/utils/python_scalars.h 
b/torch/csrc/utils/python_scalars.h index 293952c1de349..34d1cab182f29 100644 --- a/torch/csrc/utils/python_scalars.h +++ b/torch/csrc/utils/python_scalars.h @@ -152,8 +152,8 @@ inline PyObject* load_scalar(void* data, at::ScalarType scalarType) { return PyFloat_FromDouble(at::convert( *(at::Float8_e5m2fnuz*)data)); case at::kFloat8_e4m3fnuz: - return PyFloat_FromDouble(at::convert( - *(at::Float8_e5m2fnuz*)data)); + return PyFloat_FromDouble(at::convert( + *(at::Float8_e4m3fnuz*)data)); default: throw std::runtime_error("invalid type"); } diff --git a/torch/fx/graph.py b/torch/fx/graph.py index e98be65e51ceb..35e19135bba5c 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -208,6 +208,8 @@ def _rename_object(self, obj: Any, name: str): torch.float16: 'f16', torch.float8_e4m3fn: 'f8e4m3fn', torch.float8_e5m2: 'f8e5m2', + torch.float8_e4m3fnuz: 'f8e4m3fnuz', + torch.float8_e5m2fnuz: 'f8e5m2fnuz', torch.complex32: 'c32', torch.complex64: 'c64', torch.complex128: 'c128', diff --git a/torch/onnx/_type_utils.py b/torch/onnx/_type_utils.py index 2f2b6cf63f4eb..0470793ca7709 100644 --- a/torch/onnx/_type_utils.py +++ b/torch/onnx/_type_utils.py @@ -33,6 +33,8 @@ "BFloat16", "Float8E5M2", "Float8E4M3FN", + "Float8E5M2FNUZ", + "Float8E4M3FNUZ", "Undefined", ] @@ -55,6 +57,8 @@ "bfloat16", "float8_e5m2", "float8_e4m3fn", + "float8_e5m2fnuz", + "float8_e4m3fnuz", ] @@ -96,7 +100,9 @@ class JitScalarType(enum.IntEnum): BFLOAT16 = enum.auto() # 15 FLOAT8E5M2 = enum.auto() # 16 FLOAT8E4M3FN = enum.auto() # 17 - UNDEFINED = enum.auto() # 18 + FLOAT8E5M2FNUZ = enum.auto() # 18 + FLOAT8E4M3FNUZ = enum.auto() # 19 + UNDEFINED = enum.auto() # 20 @classmethod @_beartype.beartype @@ -286,6 +292,8 @@ def valid_torch_name(torch_name: Union[TorchName, str]) -> bool: JitScalarType.BFLOAT16: "BFloat16", JitScalarType.FLOAT8E5M2: "Float8E5M2", JitScalarType.FLOAT8E4M3FN: "Float8E4M3FN", + JitScalarType.FLOAT8E5M2FNUZ: "Float8E5M2FNUZ", + JitScalarType.FLOAT8E4M3FNUZ: "Float8E4M3FNUZ", JitScalarType.UNDEFINED: "Undefined", } @@ -312,6 +320,8 @@ def valid_torch_name(torch_name: Union[TorchName, str]) -> bool: JitScalarType.BFLOAT16: "bfloat16", JitScalarType.FLOAT8E5M2: "float8_e5m2", JitScalarType.FLOAT8E4M3FN: "float8_e4m3fn", + JitScalarType.FLOAT8E5M2FNUZ: "float8_e5m2fnuz", + JitScalarType.FLOAT8E4M3FNUZ: "float8_e4m3fnuz", } _TORCH_NAME_TO_SCALAR_TYPE: Dict[TorchName, JitScalarType] = { @@ -338,6 +348,8 @@ def valid_torch_name(torch_name: Union[TorchName, str]) -> bool: JitScalarType.QINT32: _C_onnx.TensorProtoDataType.INT32, JitScalarType.FLOAT8E5M2: _C_onnx.TensorProtoDataType.FLOAT8E5M2, JitScalarType.FLOAT8E4M3FN: _C_onnx.TensorProtoDataType.FLOAT8E4M3FN, + JitScalarType.FLOAT8E5M2FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E5M2FNUZ, + JitScalarType.FLOAT8E4M3FNUZ: _C_onnx.TensorProtoDataType.FLOAT8E4M3FNUZ, } # source of truth is @@ -361,6 +373,8 @@ def valid_torch_name(torch_name: Union[TorchName, str]) -> bool: JitScalarType.BFLOAT16: torch.bfloat16, JitScalarType.FLOAT8E5M2: torch.float8_e5m2, JitScalarType.FLOAT8E4M3FN: torch.float8_e4m3fn, + JitScalarType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz, + JitScalarType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz, } _DTYPE_TO_SCALAR_TYPE = {v: k for k, v in _SCALAR_TYPE_TO_DTYPE.items()} diff --git a/torch/storage.py b/torch/storage.py index f65c0806accda..2217b28f97b99 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -207,6 +207,14 @@ def float8_e4m3fn(self): """Casts this storage to float8_e4m3fn type""" return self._to(torch.float8_e4m3fn) + def 
float8_e5m2fnuz(self): + """Casts this storage to float8_e5m2fnuz type""" + return self._to(torch.float8_e5m2fnuz) + + def float8_e4m3fnuz(self): + """Casts this storage to float8_e4m3fnuz type""" + return self._to(torch.float8_e4m3fnuz) + def is_pinned(self, device: Union[str, torch.device] = 'cuda'): r"""Determine whether the CPU storage is already pinned on device. @@ -1070,6 +1078,16 @@ def float8_e4m3fn(self): _warn_typed_storage_removal() return self._to(torch.float8_e4m3fn) + def float8_e5m2fnuz(self): + """Casts this storage to float8_e5m2fnuz type""" + _warn_typed_storage_removal() + return self._to(torch.float8_e5m2fnuz) + + def float8_e4m3fnuz(self): + """Casts this storage to float8_e4m3fnuz type""" + _warn_typed_storage_removal() + return self._to(torch.float8_e4m3fnuz) + @classmethod def from_file(cls, filename, shared, size): """from_file(filename, shared=False, size=0) -> Storage diff --git a/torch/testing/_creation.py b/torch/testing/_creation.py index d02de60d35665..0b01b172a4774 100644 --- a/torch/testing/_creation.py +++ b/torch/testing/_creation.py @@ -20,7 +20,12 @@ torch.uint64, ] _FLOATING_TYPES = [torch.float16, torch.bfloat16, torch.float32, torch.float64] -_FLOATING_8BIT_TYPES = [torch.float8_e4m3fn, torch.float8_e5m2] +_FLOATING_8BIT_TYPES = [ + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, +] _COMPLEX_TYPES = [torch.complex32, torch.complex64, torch.complex128] _BOOLEAN_OR_INTEGRAL_TYPES = [torch.bool, *_INTEGRAL_TYPES] _FLOATING_OR_COMPLEX_TYPES = [*_FLOATING_TYPES, *_COMPLEX_TYPES]
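// Illustrative sketch, not part of the patch: round-tripping a value through
// the existing e5m2fnuz encoder and the new shared fp8_fnuz_to_fp32_value<we, wm>
// decoder introduced in c10/util/Float8_fnuz_cvt.h. Because fnuz formats have
// no infinity, out-of-range magnitudes collapse to the single NaN encoding
// rather than saturating to +/-inf. Assumes a c10 build.
#include <c10/util/Float8_e5m2fnuz.h>
#include <c10/util/Float8_fnuz_cvt.h>
#include <cassert>
#include <cmath>
#include <cstdint>

inline void fnuz_roundtrip_example() {
  // 3.0f needs only two mantissa bits, so it round-trips exactly.
  uint8_t bits = c10::detail::fp8e5m2fnuz_from_fp32_value(3.0f);
  assert(c10::detail::fp8_fnuz_to_fp32_value<5, 2>(bits) == 3.0f);

  // 1e6f exceeds the e5m2fnuz maximum (57344), so it encodes as NaN.
  uint8_t overflow = c10::detail::fp8e5m2fnuz_from_fp32_value(1e6f);
  assert(std::isnan(c10::detail::fp8_fnuz_to_fp32_value<5, 2>(overflow)));
}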