
[AMP] CUDA support for mixed precision pass #8294

Closed
AndrewZhaoLuo opened this issue Jun 21, 2021 · 11 comments

@AndrewZhaoLuo
Contributor

Solve issues and make the modifications needed to support CUDA for the mixed precision pass introduced here: #8069

Current initial issue, as described by @Lunderberg:

On the cuda side, it's failing a check that requires 16-bit floats to be used in pairs.

Check failed: lanes % 2 == 0 (1 vs. 0) : only support even lane for half type

This issue is complete when the unit tests pass for the CUDA target.
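
As a point of reference, here is a minimal sketch (mine, not from the issue) of how the pass would be exercised end-to-end on CUDA once this is fixed, assuming the ToMixedPrecision pass added in #8069; mod and params are placeholders for any Relay model:

```python
# Sketch: apply the AMP pass and build for CUDA (assumes a CUDA-enabled TVM build).
import tvm
from tvm import relay

def build_fp16_cuda(mod, params):
    # Convert eligible ops to float16 (pass from #8069).
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.ToMixedPrecision("float16")(mod)
    # Building for "cuda" is where the failures tracked by this issue surface.
    with tvm.transform.PassContext(opt_level=3):
        return relay.build(mod, target="cuda", params=params)
```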

@AndrewZhaoLuo
Contributor Author

cc @Lunderberg @masahi

@masahi
Member

masahi commented Jun 22, 2021

Got this error when running test_convert_single_conv() with cuda or opencl:

    kernel[co][ci][r_kh][r_kw] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
  File "/home/masa/projects/dev/tvm/python/tvm/tir/expr.py", line 77, in __mul__
    return _generic.multiply(self, other)
  File "/home/masa/projects/dev/tvm/python/tvm/topi/generic_op_impl.py", line 83, in _tensor_bop_impl
    return orig_bop(lhs, rhs)
  File "/home/masa/projects/dev/tvm/python/tvm/tir/generic.py", line 86, in multiply
    return _ffi_api._OpMul(lhs, rhs, span)
  File "/home/masa/projects/dev/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::runtime::TypedPackedFunc<tvm::PrimExpr (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}>(tvm::{lambda(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)#5}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)
  1: tvm::mul(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  0: tvm::BinaryOpMatchTypes(tvm::PrimExpr&, tvm::PrimExpr&, tvm::Span)
  File "../src/tir/op/op.cc", line 144
TVMError: Cannot match type float16 vs float32
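
For reference, a hypothetical repro along the lines of what test_convert_single_conv() exercises (shapes and names are mine, not the test's); whether the build reaches the conv2d_winograd.py compute that raises this error depends on which conv2d strategy is selected:

```python
# Hypothetical single-conv repro; shapes are illustrative only.
import tvm
from tvm import relay

data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, padding=(1, 1), channels=8, kernel_size=(3, 3))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

mod = relay.transform.InferType()(mod)
mod = relay.transform.ToMixedPrecision("float16")(mod)

with tvm.transform.PassContext(opt_level=3):
    relay.build(mod, target="cuda")  # building for cuda/opencl is where the trace above was raised
```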

@AndrewZhaoLuo
Contributor Author

I suspect this has to do with the schedule not actually supporting accumulation dtypes.

Can you post the rest of the trace?

@AndrewZhaoLuo
Contributor Author

Hmm, yeah, the problem is what I suspected. Specifically, in python/tvm/topi/cuda/conv2d_winograd.py the winograd matrix G is cast to the output dtype while the kernel isn't, so there is a type mismatch.

In general it seems reasonable to have implicit type promotion to higher-bit floating point types. Furthermore, it might also be good for most binary arithmetic ops to have an output_dtype.

E.g. right now there isn't a good way to represent adding two fp16 numbers into an fp32 result.

Newer NVIDIA GPUs support this as a primitive operation, so maybe we should have a better representation.
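
To make that concrete, here is a minimal TE sketch (mine, not the conv2d_winograd.py code) of the pattern that avoids the error: cast both operands to the accumulation dtype before multiplying, rather than only one of them:

```python
# fp16 inputs with fp32 accumulation: both sides of the multiply are promoted,
# so BinaryOpMatchTypes sees matching types.
import tvm
from tvm import te

n = te.var("n")
a = te.placeholder((n,), dtype="float16", name="a")
b = te.placeholder((n,), dtype="float16", name="b")
k = te.reduce_axis((0, n), name="k")

acc = te.compute(
    (1,),
    lambda _: te.sum(a[k].astype("float32") * b[k].astype("float32"), axis=k),
    name="acc",
)
# Casting only one of the two operands is exactly the kind of mismatch that
# triggers "Cannot match type float16 vs float32".
```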

@AndrewZhaoLuo
Contributor Author

AndrewZhaoLuo commented Jun 22, 2021

I'm just going to turn off accumulating to fp32 for now. I don't want to manually look at every single schedule ever written to check for correctness.

With accumulation turned off, all the unit tests except one pass. The one that doesn't pass is the problem described by @Lunderberg. This one seems trickier, since I don't understand how the CUDA codegen works at all:

E               rt_mod_dev = codegen.build_module(mod_dev, target) if len(mod_dev.functions) != 0 else None
E             File "/home/aluo/tvm/python/tvm/target/codegen.py", line 39, in build_module
E               return _ffi_api.Build(mod, target)
E             File "/home/aluo/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
E               raise get_last_ffi_error()
E             19: TVMFuncCall
E             18: void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
E             17: tvm::codegen::Build(tvm::IRModule, tvm::Target)
E             16: void tvm::runtime::TypedPackedFunc<tvm::runtime::Module (tvm::IRModule, tvm::Target)>::AssignTypedLambda<tvm::runtime::Module (*)(tvm::IRModule, tvm::Target)>(tvm::runtime::Module (*)(tvm::IRModule, tvm::Target), std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
E             15: tvm::codegen::BuildCUDA(tvm::IRModule, tvm::Target)
E             14: tvm::codegen::CodeGenC::AddFunction(tvm::tir::PrimFunc const&)
E             13: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
E             12: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
E             11: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
E             10: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
E             9: tvm::codegen::CodeGenCUDA::VisitStmt_(tvm::tir::AttrStmtNode const*)
E             8: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::AttrStmtNode const*)
E             7: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
E             6: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::IfThenElseNode const*)
E             5: tvm::NodeFunctor<void (tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::tir::StmtFunctor<void (tvm::tir::Stmt const&)>*) const
E             4: tvm::codegen::CodeGenC::VisitStmt_(tvm::tir::StoreNode const*)
E             3: tvm::codegen::CodeGenC::PrintExpr[abi:cxx11](tvm::PrimExpr const&)
E             2: tvm::codegen::CodeGenC::PrintExpr(tvm::PrimExpr const&, std::ostream&)
E             1: tvm::codegen::CodeGenCUDA::VisitExpr_(tvm::tir::BroadcastNode const*, std::ostream&)
E             0: tvm::codegen::CodeGenCUDA::PrintType(tvm::runtime::DataType, std::ostream&)
E             File "/home/aluo/tvm/src/target/source/codegen_cuda.cc", line 149
E           TVMError: 
E           ---------------------------------------------------------------
E           An error occurred during the execution of TVM.
E           For more information, please see: https://tvm.apache.org/docs/errors.html
E           ---------------------------------------------------------------
E           
E             Check failed: lanes % 2 == 0 (1 vs. 0) : only support even lane for half type

@AndrewZhaoLuo
Contributor Author

Yeah, so the test failing with this error, tests/python/relay/test_to_mixed_precision.py, works if you change the input size of the model to a number divisible by 2. It probably shouldn't be this brittle...

@Lunderberg
Contributor

From what I can tell, the float16 values are packed into uint32 when not in use, and are cast to float16 when used. I think there will need to be some special handling to pad out the calls to make_uintN (e.g. make_uint2 for a length-3 array of float16s packed into 64 bits with the last 16 bits empty), or the packing will need to change. Either way, it looks like there isn't an immediate fix, and I don't know the CUDA codegen as well as the Vulkan side, so I can't say whether other issues would come up as well.
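
To illustrate the constraint (a rough sketch of mine, assuming a CUDA-enabled build; the exact lane count the failing test produces depends on its schedule), float16 only has a packed CUDA representation for even vector widths:

```python
# Vectorized float16 on CUDA: an even vector width (here 2) maps onto the
# packed uint-based types the codegen emits; an odd width has no packed
# representation and trips "only support even lane for half type".
import tvm
from tvm import te

n = 64
A = te.placeholder((n,), dtype="float16", name="A")
B = te.compute((n,), lambda i: A[i] + tvm.tir.const(1, "float16"), name="B")

s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], factor=2)  # even factor -> lanes % 2 == 0
s[B].bind(xo, te.thread_axis("threadIdx.x"))
s[B].vectorize(xi)

mod = tvm.build(s, [A, B], target="cuda")
print(mod.imported_modules[0].get_source())  # inspect the packed fp16 CUDA code
```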

@AndrewZhaoLuo
Contributor Author

Yep, I'm not familiar with the CUDA codegen either. I can't trivially find the ops which cause this. Looks like another TODO.

@junrushao1994 do you have any ideas?

@AndrewZhaoLuo AndrewZhaoLuo changed the title CUDA support for mixed precision pass [AMP] CUDA support for mixed precision pass Jun 25, 2021
@AndrewZhaoLuo
Contributor Author

With PR #8341 we can tune some models. Results here: https://docs.google.com/spreadsheets/d/12lgyfuHaRS-X4uG-1iQOV8oAuPpuVAbspcmkOSPRFHQ

We see good speedups, especially for BERT.

@masahi
Member

masahi commented Sep 24, 2021

I finally finished collecting data on FP16 performance using Tensor Cores. Since the NHWC conv2d Tensor Core schedule requires the batch size to be a multiple of at least 8, all batch sizes are 8. The speedup over FP32 (Ansor), which is a strong baseline, is mixed. I expected better performance from Tensor Cores, but I guess our Tensor Core schedules have room for improvement (I also hit a lot of errors when tuning them, due to invalid schedules).

In most cases, we are much slower than TensorRT (not sure whether the TensorRT deeplabv3 number is a bit off).

All numbers are in milliseconds, measured on an RTX 3070. All models are in the NHWC layout.

| Model name | Input size | FP32 (Auto scheduler, no tensorcore) | FP16 (AutoTVM, tensorcore) | FP16 TensorRT |
|---|---|---|---|---|
| resnet50 | (8, 3, 224, 224) | 8.61 | 4.14 | 2.53 |
| efficientnet_v2 | (8, 3, 224, 224) | 21.6 | 13.2 | 5.25 |
| YOLOv5l | (8, 3, 512, 512) | 32.4 | 13.22 | NA |
| DETR-R50 | (8, 3, 800, 750) | 108.3 | 80.5 | NA |
| deeplabv3_mobilenet_v3_large | (8, 3, 512, 512) | 22.6 | 15.9 | 19.2 |
| bert_large | (8, 128) | 109.9 | 24.2 | 14.0 |

@masahi
Member

masahi commented Sep 24, 2021

I think we can close this now
