diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 0f40e92183..c73f560565 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -4658,7 +4658,6 @@ def test_grouped_gemm_quant_cute_matches_mxfp8_quantized() -> None: norm_const_tensor=None, prob_tensor=inputs["prob_tensor"], acc_dtype=torch.float32, - c_dtype=torch.bfloat16, d_dtype=torch.bfloat16, cd_major="n", sf_vec_size=32, diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index aca49e9866..29273a5b47 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -687,7 +687,6 @@ def fuser_backward( "norm_const_tensor": None, "prob_tensor": torch.ones((out_shape[0], 1, 1), dtype=torch.float32, device=device), "acc_dtype": torch.float32, - "c_dtype": dtype, "d_dtype": dtype, "cd_major": "n", "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE, diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 90c4204f06..cad31e2c50 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -436,7 +436,6 @@ def fuser_forward( "norm_const_tensor": None, "prob_tensor": fc2_scales_tensor, "acc_dtype": torch.float32, - "c_dtype": dtype, "d_dtype": dtype, "cd_major": "n", "sf_vec_size": MXFP8_BLOCK_SCALING_SIZE,