diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index f2c3701e..e93c9884 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -265,7 +265,7 @@ def get_fp8_config(cls, quantization_calibration_method: str): module_path=".*", # Apply to all modules weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, + bwd_qtype=jnp.float8_e4m3fn, bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=quantization_calibration_method, diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 84efa064..26ea0f02 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -316,7 +316,7 @@ def test_get_qt_provider(self, mock_qt_rule): module_path=".*", # Apply to all modules weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, - bwd_qtype=jnp.float8_e5m2, + bwd_qtype=jnp.float8_e4m3fn, bwd_use_original_residuals=True, disable_channelwise_axes=True, # per_tensor calibration weight_calibration_method=config_fp8_full.quantization_calibration_method,