[PyTorch] Add support for FP8 current scaling in operation-based API#1858
ksivaman merged 11 commits into NVIDIA:main
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch L1
```python
if isinstance(
    input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
):
```
Suggested change:
```diff
-if isinstance(
-    input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-):
+if input_quantizer._get_compatible_recipe().float8_per_tensor_scaling():
```
I think this fix adds overhead and makes the code less clear. _get_compatible_recipe (added in #1724) returns a type since it's used for type comparisons. If we want to use float8_per_tensor_scaling, then we need to instantiate a new DelayedScaling or Float8CurrentScaling that doesn't match the current recipe. Alternatively, we can compare classes with is or isinstance, which is basically what we're already doing.
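The class-comparison approach the comment favors can be sketched in plain Python (a hedged illustration with dummy stand-in classes, not the real Transformer Engine quantizers):

```python
# Dummy stand-ins for the TE quantizer classes (illustration only).
class Float8Quantizer: ...
class Float8CurrentScalingQuantizer: ...
class MXFP8Quantizer: ...

def uses_per_tensor_scaling(quantizer):
    # Both delayed scaling and current scaling are per-tensor schemes,
    # so a single isinstance check over both classes dispatches on them
    # without instantiating a recipe object just to query it.
    return isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer))

print(uses_per_tensor_scaling(Float8CurrentScalingQuantizer()))  # True
print(uses_per_tensor_scaling(MXFP8Quantizer()))  # False
```

The design point is that `isinstance` is a direct type comparison, whereas going through `_get_compatible_recipe()` would require constructing a recipe instance only to ask it a question the class identity already answers.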
```python
if isinstance(
    weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
) and isinstance(weight, Float8TensorBase):
```
Suggested change:
```diff
-if isinstance(
-    weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
-) and isinstance(weight, Float8TensorBase):
+if weight_quantizer._get_compatible_recipe().float8_per_tensor_scaling() and
+    isinstance(weight, Float8TensorBase):
```
I don't think _get_compatible_recipe benefits us here.
```python
# Recipe-specific configuration
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe.float8_current_scaling():
```
Note: It's unfortunate that this is required here. Most of what I would suggest is probably out of scope for this PR.
I don't think there's a good way around this. The amax reduction logic is specific to linear layers with FP8 current scaling, so there's no other logical place to put it.
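For illustration, the amax-reduction-then-scale step can be sketched in plain Python (a hedged sketch with made-up rank values; `reduce_amax` is a hypothetical stand-in for a `torch.distributed` max all-reduce, not a TE function):

```python
# Hypothetical sketch of the amax reduction discussed above: with FP8
# current scaling, each rank computes a local amax, the ranks take a
# max-reduction, and the shared quantization scale is derived from the
# global amax. Plain lists stand in for torch.distributed here.
E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3

def reduce_amax(local_amaxes):
    # Stand-in for an all-reduce with op=MAX across data-parallel ranks.
    return max(local_amaxes)

local_amaxes = [1.5, 4.0, 0.25]         # per-rank amax values (made up)
global_amax = reduce_amax(local_amaxes)  # 4.0
scale = E4M3_MAX / global_amax           # 112.0
print(global_amax, scale)
```

Because the reduction must happen before the scale is computed, and only linear-layer inputs/weights need it, the logic ends up tied to the linear-layer code path as the comment notes.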
```python
elif quantization == "fp8_current_scaling":
    quantizer = Float8CurrentScalingQuantizer(
        fp8_dtype=tex.DType.kFloat8E4M3,
        device=test_device,
```
Orthogonal, but why does this device arg exist here?
Mostly for completeness. We have a ref_device option since I prefer computing the reference implementation on CPU, which helps catch CUDA-related bugs.
```python
# Import utility functions
_current_file = pathlib.Path(__file__).resolve()
sys.path.append(str(_current_file.parent))
```
Not a fan of this, but I see we currently do this for other tests as well, so it's OK for this PR. In a separate refactor we should convert the tests dir to also be a module.
I agree. Ideally we would use relative imports like from .utils import dtype_tols, but Python doesn't allow that if you are running the script directly, which is how Pytest runs the tests.
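The pattern being discussed can be demonstrated in isolation (a hedged sketch: the temporary directory and the `helper_utils` module name are made up here to simulate the tests directory):

```python
import importlib
import pathlib
import sys
import tempfile

# Hypothetical sketch of the sys.path pattern above: a test script appends
# its own directory to sys.path so it can import a sibling helper module
# even when run as a top-level script (which is how pytest executes it).
# A temporary directory stands in for the tests dir.
with tempfile.TemporaryDirectory() as d:
    tests_dir = pathlib.Path(d)
    (tests_dir / "helper_utils.py").write_text(
        "def dtype_tols():\n    return {'rtol': 1e-3}\n"
    )
    # Mirrors sys.path.append(str(_current_file.parent)) in the test file.
    sys.path.append(str(tests_dir))
    importlib.invalidate_caches()
    helper_utils = importlib.import_module("helper_utils")
    print(helper_utils.dtype_tols())
```

With a proper package, this would instead be `from .utils import dtype_tols`, which fails for directly-executed scripts because they have no package context.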
Force-pushed from 5e2490c to 0496d1f
/te-ci pytorch L1
/te-ci pytorch L1
Description
This PR makes some minor changes needed for te.Sequential to support FP8 current scaling. It also modifies the te.Sequential unit tests to handle quantization schemes other than FP8 delayed scaling more gracefully.
Type of change
Changes
Checklist: