Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def forward(
weight.requires_grad and parallel_mode == "column" and sequence_parallel
)

- # Input with column-wise usage is needed for dgrad GEMM.
+ # Input with column-wise usage is needed for wgrad GEMM.
if backward_needs_input:
if isinstance(ln_out, QuantizedTensor):
# For sequence parallel in vanilla FP8, rowwise data is
Expand All @@ -357,6 +357,11 @@ def forward(
if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
ln_out.update_usage(rowwise_usage=False)

# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)

if cpu_offloading:
if fp8 and weightmat is not None:
set_offloading_param(weightmat, "weight_offloading", True)
Expand Down
8 changes: 8 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,14 @@ def forward(
ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
extra_output=rs_out,
)

# Weight with column-wise usage is needed for dgrad GEMM.
if is_grad_enabled and inp.requires_grad:
if isinstance(fc1_weight_final, QuantizedTensor):
fc1_weight_final.update_usage(columnwise_usage=True)
if isinstance(fc2_weight_final, QuantizedTensor):
fc2_weight_final.update_usage(columnwise_usage=True)

if not is_grad_enabled:
clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
else:
Expand Down
5 changes: 5 additions & 0 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ def forward(
inputmat.update_usage(rowwise_usage=False)
saved_inputmat = inputmat

# Weight with column-wise usage is needed for dgrad GEMM.
if inp.requires_grad:
if isinstance(weightmat, QuantizedTensor):
weightmat.update_usage(columnwise_usage=True)

if cpu_offloading:
set_offloading_param(weight, "weight_offloading", True)
set_offloading_param(weightmat, "weight_offloading", True)
Expand Down