Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1469,7 +1469,8 @@ def forward(
fwd_nominal_dtype = q.dtype
is_input_fp8 = isinstance(q, QuantizedTensorStorage)
is_output_fp8 = fp8_output
- is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+ _use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
+ is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
Expand Down Expand Up @@ -2063,20 +2064,17 @@ def forward(
# prepare for return and ctx saves
out_fp8 = None
out_f16 = out.to(fwd_nominal_dtype)
- if fp8 and (
- is_output_fp8
- or (
- is_bwd_fp8
- and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
- and not fp8_recipe.mxfp8()
- )
+ if (fp8 and is_output_fp8) or (
+ is_bwd_fp8
+ and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)
+ and not fp8_recipe.mxfp8()
):
out_fp8 = O_quantizer(out_f16)
out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16

ctx.layer_number = layer_number
ctx.fp8_recipe = fp8_recipe
- ctx.fp8 = fp8 and is_bwd_fp8
+ ctx.fp8 = is_bwd_fp8

kv_fp8 = None
kv = p2p_comm_buffers[-1]
Expand Down Expand Up @@ -3063,7 +3061,8 @@ def forward(
), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage."
is_input_fp8 = isinstance(q, QuantizedTensorStorage)
is_output_fp8 = fp8_output
- is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+ _use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
+ is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None:
fp8_recipe = fp8_meta["local_recipes"][0]
Expand Down Expand Up @@ -3306,12 +3305,12 @@ def forward(
or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16)
)
)
- if fp8 and (is_output_fp8 or bwd_requires_o_fp8):
+ if (fp8 and is_output_fp8) or bwd_requires_o_fp8:
out_fp8 = O_quantizer(out_f16)
out_ret = out_fp8 if is_output_fp8 else out_f16

# save tensors for backward
- ctx.fp8 = fp8 and is_bwd_fp8
+ ctx.fp8 = is_bwd_fp8
ctx.fp8_recipe = fp8_recipe
fp8_tensors = (None, None, None, None)
f16_tensors = (None, None, None, None)
Expand Down Expand Up @@ -3931,7 +3930,8 @@ def forward(
), "q, k, v must be of the same class, e.g. torch.Tensor or QuantizedTensorStorage."
is_input_fp8 = isinstance(q, QuantizedTensorStorage)
is_output_fp8 = fp8_output
- is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1"))
+ _use_fp8_dpa_bwd = bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))
+ is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd
# recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE;
# may be different from fp8_meta["recipe"]
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
Expand Down Expand Up @@ -4161,7 +4161,7 @@ def forward(
ctx.orig_o_shape = orig_o_shape

# save tensors for backward
- ctx.fp8 = fp8 and is_bwd_fp8
+ ctx.fp8 = is_bwd_fp8
fp8_tensors = (None, None, None, None)
f16_tensors = (None, None, None, None)
if is_training:
Expand Down
Loading