diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4d4d5ca78b..f4fe913d57 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -348,7 +348,7 @@ def forward( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) - # Input with column-wise usage is needed for dgrad GEMM. + # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: if isinstance(ln_out, QuantizedTensor): # For sequence parallel in vanilla FP8, rowwise data is @@ -357,6 +357,11 @@ def forward( if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather: ln_out.update_usage(rowwise_usage=False) + # Weight with column-wise usage is needed for dgrad GEMM. + if inp.requires_grad: + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) + if cpu_offloading: if fp8 and weightmat is not None: set_offloading_param(weightmat, "weight_offloading", True) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index f20c95c0fc..6f6d0d98fd 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -448,6 +448,14 @@ def forward( ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None, extra_output=rs_out, ) + + # Weight with column-wise usage is needed for dgrad GEMM. + if is_grad_enabled and inp.requires_grad: + if isinstance(fc1_weight_final, QuantizedTensor): + fc1_weight_final.update_usage(columnwise_usage=True) + if isinstance(fc2_weight_final, QuantizedTensor): + fc2_weight_final.update_usage(columnwise_usage=True) + if not is_grad_enabled: clear_tensor_data(act_out, fc1_out_without_bias, fc1_out) else: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f96355a678..71d38e2822 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -274,6 +274,11 @@ def forward( inputmat.update_usage(rowwise_usage=False) saved_inputmat = inputmat + # Weight with column-wise usage is needed for dgrad GEMM. + if inp.requires_grad: + if isinstance(weightmat, QuantizedTensor): + weightmat.update_usage(columnwise_usage=True) + if cpu_offloading: set_offloading_param(weight, "weight_offloading", True) set_offloading_param(weightmat, "weight_offloading", True)