[DataType] Initial support of fp8 (e4m3/e5m2) #14863
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
Thanks for such great work! LGTM except for some nits.
Thanks for the PR - I think adding support for fp8 is very positive.
I wonder if we should have an RFC, as this is adding a feature that may impact passes, schedules etc.
Adding fp8 is unlikely to impact passes or schedules in a significant way, as most passes can simply opt out of the dtype :) Having broad awareness is good; cross-linking https://discuss.tvm.apache.org/t/tvm-support-for-fp8-a-discussion/14656, we can bring more discussion there as well.
Thank you everyone, looking forward to more FP8 improvements. This technology is still early, so it is possible that we will need to iterate a few times. It is great to get timely support in place so the community can start trying it out and learning.
looks like something unexpected happens with the pass:

import tvm
from tvm.script import tir as T

M = 64
N = 64

@tvm.script.ir_module
class MyModule:
    @T.prim_func
    def main(a: T.handle, b: T.handle):
        T.func_attr({"global_symbol": "main"})
        A = T.match_buffer(a, (M, N))
        B = T.match_buffer(b, (M, N))
        for i, j in T.grid(M, N):
            with T.block("B"):
                vi, vj = T.axis.remap("SS", [i, j])
                B[vi, vj] = A[vi, vj] * 2.0
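The traceback below shows the script then builds the scheduled module for rocm; a hedged reconstruction of that missing step (only the tvm.build call appears verbatim in the traceback, the tvm.tir.Schedule wrapper is an assumption):

# Assumed: wrap the module in a schedule, since the traceback references sch.mod.
sch = tvm.tir.Schedule(MyModule)
# Taken verbatim from the traceback (memory_copy.py, line 37).
rocm_mod = tvm.build(sch.mod, target="rocm --host=llvm")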
Traceback (most recent call last):
File "memory_copy.py", line 37, in <module>
rocm_mod = tvm.build(sch.mod, target="rocm --host=llvm")
File "/home/aiscuser/v-leiwang3/tvm/python/tvm/driver/build_module.py", line 281, in build
rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
File "/home/aiscuser/v-leiwang3/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 238, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
12: TVMFuncCall
11: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)>::AssignTypedLambda<tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}>(tvm::{lambda(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target)#6}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
10: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
9: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
8: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
3: tvm::tir::transform::PrimFuncPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
2: _ZN3tvm7runtime13PackedFuncObj
1: tvm::runtime::TypedPackedFunc<tvm::tir::PrimFunc (tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::tir::transform::FP8StorageLegalize()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::tir::transform::FP8StorageLegalize()::{lambda(tvm::tir::PrimFunc, tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
0: tvm::tir::StorageLegalizer::Legalize(tvm::tir::PrimFunc)
File "/home/aiscuser/v-leiwang3/tvm/src/tir/transforms/unsupported_dtype_legalize.cc", line 478
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: func->buffer_map.size() == 0 (2 vs. 0) : This pass must be called after MakePackedAPI
@LeiWang1999 it should have been fixed in #15102.
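For context, the failing check enforces that FP8StorageLegalize only runs on PrimFuncs whose buffer_map has already been lowered away by MakePackedAPI. A minimal sketch of applying the passes in that order by hand (the BindTarget step and the target string are assumptions for illustration; the default lowering pipeline normally takes care of this ordering):

import tvm
from tvm import tir

# Illustrative manual pipeline: bind a target so MakePackedAPI can run,
# flatten the buffer_map, and only then apply the fp8 storage legalizer
# added in this PR.
target = tvm.target.Target("rocm", host="llvm")
mod = tir.transform.BindTarget(target)(MyModule)
mod = tir.transform.MakePackedAPI()(mod)
mod = tir.transform.FP8StorageLegalize()(mod)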
Motivation
Recently, NVIDIA announced official support for the fp8 data types e4m3 and e5m2: the first has 4 exponent bits and 3 mantissa bits, while the second has 5 exponent bits and 2 mantissa bits. NVIDIA encourages using e4m3 for the forward pass and e5m2 (which has a larger dynamic range) for the backward pass. TVM currently has no support for these data types. As a first step towards fp8 support, this PR adds the new type codes e4m3_float8 and e5m2_float8, and implements the legalization passes FP8ComputeLegalize and FP8StorageLegalize so that the new types can be used on backends without native fp8 support.
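A minimal sketch of what the new type codes enable at the TE level (the round-trip compute below is illustrative and not taken from the PR; on targets without native fp8, the legalization passes are expected to rewrite it, mirroring the existing BF16 legalization):

import tvm
from tvm import te

# Illustrative only: cast a float16 input to the new e4m3 type and back.
n = te.var("n")
A = te.placeholder((n,), dtype="float16", name="A")
B = te.compute(
    (n,),
    lambda i: A[i].astype("e4m3_float8").astype("float16"),
    name="B",
)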
Future Work
Use the wgmma.mma_async.sync PTX assembly to target the fp8 tensor cores on Ada/Hopper.
Notes
Infinity and NaN are not handled in our legalization passes (this behavior matches our previous BF16 legalization implementation) because supporting them on the software side is costly. It is the user's responsibility to guarantee that the conversion is safe.
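One way to keep such conversions safe on the user side is to clamp values into the finite fp8 range before casting. A small NumPy sketch (the range constants are the standard finite maxima of the e4m3/e5m2 formats, not something this PR exposes):

import numpy as np

E4M3_MAX = 448.0     # largest finite e4m3 value
E5M2_MAX = 57344.0   # largest finite e5m2 value

def saturate_for_fp8(x: np.ndarray, max_val: float = E4M3_MAX) -> np.ndarray:
    # Clamp into the representable range so a later fp8 cast cannot
    # overflow to values the legalization pass does not handle.
    return np.clip(x, -max_val, max_val)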
cc @MasterJH5574 @masahi @tqchen @Hzfengsy