
[DataType] Initial support of fp8 (e4m3/e5m2) #14863

Merged
merged 29 commits into apache:main on Jun 1, 2023

Conversation

yzh119
Member

@yzh119 yzh119 commented May 16, 2023

Motivation

Recently NVIDIA announced official support for the fp8 data types e4m3 and e5m2: the former has 4 exponent bits and 3 mantissa bits, while the latter has 5 exponent bits and 2 mantissa bits, and NVIDIA recommends using e4m3 for the forward pass and e5m2 (larger dynamic range) for the backward pass. TVM currently has no support for these data types. As a first step toward fp8 support, this PR adds new type codes for e4m3_float8 and e5m2_float8 and implements the legalization passes FP8ComputeLegalize and FP8StorageLegalize so that these types can be used on backends without native fp8 support.
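
For concreteness, here is a small pure-Python sketch (illustration only, not code from this PR) that decodes both layouts and shows why e5m2 has the larger dynamic range. Inf/NaN encodings are deliberately ignored here, mirroring the note below.

# Illustration only: decode an fp8 byte given its bit layout.
# e4m3: 1 sign, 4 exponent (bias 7), 3 mantissa bits.
# e5m2: 1 sign, 5 exponent (bias 15), 2 mantissa bits.
def decode_fp8(byte: int, exp_bits: int, man_bits: int) -> float:
    bias = (1 << (exp_bits - 1)) - 1
    sign = -1.0 if byte >> 7 else 1.0
    exp = (byte >> man_bits) & ((1 << exp_bits) - 1)
    man = byte & ((1 << man_bits) - 1)
    if exp == 0:  # subnormal: no implicit leading 1
        return sign * man * 2.0 ** (1 - bias - man_bits)
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

print(decode_fp8(0b0_1111_110, exp_bits=4, man_bits=3))  # e4m3 max finite: 448.0
print(decode_fp8(0b0_11110_11, exp_bits=5, man_bits=2))  # e5m2 max finite: 57344.0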

Future Work

  • Emit CUDA fp8 primitives in CUDA codegen.
  • Support wgmma.mma_async.sync ptx assembly to use fp8 tensor cores in Ada/Hopper.
  • Support fp8 in dlpack.

Notes

Infinity and NaN are not handled in our legalization passes (this matches our previous BF16 legalization implementation) because supporting them on the software side is costly. It is the user's responsibility to guarantee that the conversion is safe.
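
For example, one user-side strategy (a NumPy sketch, not something this PR does for you; 448 is e4m3's largest finite magnitude) is to sanitize inputs before casting:

import numpy as np

E4M3_MAX = 448.0  # largest finite e4m3 magnitude

def sanitize_for_e4m3(x: np.ndarray) -> np.ndarray:
    # Replace NaN/Inf and saturate out-of-range values so a later
    # float32 -> e4m3 conversion never has to represent Inf or NaN.
    x = np.nan_to_num(x, nan=0.0, posinf=E4M3_MAX, neginf=-E4M3_MAX)
    return np.clip(x, -E4M3_MAX, E4M3_MAX)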

Reference

cc @MasterJH5574 @masahi @tqchen @Hzfengsy

@tvm-bot
Collaborator

tvm-bot commented May 16, 2023

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

  • No users to tag found in teams: datatype. See #10317 for details.

Generated by tvm-bot

@github-actions github-actions bot requested a review from Hzfengsy May 16, 2023 21:53
Member

@Hzfengsy Hzfengsy left a comment


Thanks for such great work! LGTM except for some nits.

Contributor

@leandron leandron left a comment


Thanks for the PR - I think adding support for fp8 is very positive.

I wonder if we should have an RFC, as this is adding a feature that may impact passes, schedules, etc.

@tqchen
Member

tqchen commented May 17, 2023

Adding fp8 is unlikely to impact passes or schedules in a significant way, as most passes can simply opt out of the dtype :)

Having broad awareness is good. Cross-linking https://discuss.tvm.apache.org/t/tvm-support-for-fp8-a-discussion/14656; we can bring more discussion there as well.

@tqchen tqchen merged commit b13be93 into apache:main Jun 1, 2023
18 checks passed
@tqchen
Member

tqchen commented Jun 1, 2023

Thank you everyone, looking forward to more FP8 improvements. This technology is still early, so we may need to iterate a few times. It is great to get timely support in place so the community can start trying it out and learning.

@LeiWang1999
Contributor

LeiWang1999 commented Jun 13, 2023

It looks like something unexpected happens with the FP8StorageLegalize pass under the ROCm backend. @yzh119

import tvm
from tvm.script import tir as T

M = 64
N = 64

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (M, N))
        B = T.match_buffer(b, (M, N))
        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0

# Building for the ROCm backend triggers the error below.
sch = tvm.tir.Schedule(MyModule)  # schedule creation reconstructed from the traceback's `sch.mod`
rocm_mod = tvm.build(sch.mod, target="rocm --host=llvm")

Traceback (most recent call last):
  File "memory_copy.py", line 37, in <module>
    rocm_mod = tvm.build(sch.mod, target="rocm --host=llvm")
  File "/home/aiscuser/v-leiwang3/tvm/python/tvm/driver/build_module.py", line 281, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "/home/aiscuser/v-leiwang3/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  12: TVMFuncCall
  11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  10: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
  9: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
  8: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: _ZN3tvm7runtime13PackedFuncObj
  1: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::FP8StorageLegalize()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::FP8StorageLegalize()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  0: tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
  File "/home/aiscuser/v-leiwang3/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 478
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: func->buffer_map.size() == 0 (2 vs. 0) : This pass must be called after MakePackedAPI

@yzh119
Member Author

yzh119 commented Jun 14, 2023

@LeiWang1999 it should have been fixed in #15102.
