From b6443bb7ffc9dda41256aebfa400b556998b6cdd Mon Sep 17 00:00:00 2001 From: Yuzhong Wang Date: Tue, 26 Aug 2025 20:13:14 -0700 Subject: [PATCH 1/7] fix memory overhead of all gather from sequence parallel Signed-off-by: Yuzhong Wang --- .../pytorch/module/layernorm_linear.py | 22 +++++++++++++++++-- transformer_engine/pytorch/module/linear.py | 10 +++++++++ .../_internal/float8_blockwise_tensor_base.py | 3 +++ 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 04e3eba7da..4ac00f964c 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -353,8 +353,14 @@ def forward( # Deallocate GEMM input tensor if no longer needed if not weight.requires_grad and not return_layernorm_output: - ln_out = ln_out_total = None clear_tensor_data(ln_out, ln_out_total) + ln_out = ln_out_total = None + elif ( + ln_out_total is not ln_out_return + and ln_out_total is not ln_out + ): + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ -892,7 +898,19 @@ def wgrad_gemm( del grad_bias_ # Deallocate input tensor if permitted - if not ctx.return_layernorm_output: + if ( + not ctx.return_layernorm_output + and not ctx.return_layernorm_output_gathered + ): + # Do not need to return layernorm output + clear_tensor_data(ln_out) + elif ( + ctx.return_layernorm_output_gathered + and ctx.ln_out_needs_gather + ): + # ln_out is not the returned tensor + clear_tensor_data(ln_out) + if ctx.ln_out_needs_gather: clear_tensor_data(ln_out_total) # Update grad input if overlapping reduce-scatter with wgrad GEMM diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 695cbb4e61..513353667a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -317,6 +317,11 @@ def forward( # Finished forward GEMM... 
# ------------------------------------------------------ + # Deallocate GEMM input tensor if no longer needed + if with_input_all_gather_nccl: + clear_tensor_data(inputmat_total) + inputmat_total = None + # ------------------------------------------------------ # Prepare output tensor # Note: Perform tensor-parallel communication @@ -881,6 +886,11 @@ def wgrad_gemm( # Deallocate input tensor if permitted if ctx.owns_input: clear_tensor_data(inputmat_total) + elif ctx.backward_input_needs_gather: + clear_tensor_data(inputmat_total) + + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index adffe7c580..2c7a54011b 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -349,9 +349,12 @@ def _create_columnwise(self): def _transpose_columnwise_data(self): """Plainly transpose the columnwise data and scale inv.""" if self._columnwise_data is not None: + _old_data = self._columnwise_data self._columnwise_data = tex.fp8_transpose( self._columnwise_data, self._fp8_dtype, out=None ) + _old_data.data = _empty_tensor() + del _old_data def __repr__(self): if self._rowwise_data is not None: From 743406beb579f6d5ae139c00fe447c17a8d7ac8a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 10:21:28 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/module/layernorm_linear.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4ac00f964c..74706d8ca3 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -355,10 +355,7 @@ def forward( if not weight.requires_grad and not return_layernorm_output: clear_tensor_data(ln_out, ln_out_total) ln_out = ln_out_total = None - elif ( - ln_out_total is not ln_out_return - and ln_out_total is not ln_out - ): + elif ln_out_total is not ln_out_return and ln_out_total is not ln_out: clear_tensor_data(ln_out_total) ln_out_total = None @@ -898,16 +895,10 @@ def wgrad_gemm( del grad_bias_ # Deallocate input tensor if permitted - if ( - not ctx.return_layernorm_output - and not ctx.return_layernorm_output_gathered - ): + if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: # Do not need to return layernorm output clear_tensor_data(ln_out) - elif ( - ctx.return_layernorm_output_gathered - and ctx.ln_out_needs_gather - ): + elif ctx.return_layernorm_output_gathered and ctx.ln_out_needs_gather: # ln_out is not the returned tensor clear_tensor_data(ln_out) if ctx.ln_out_needs_gather: From ce0a634c93482dd101f1082d5a5b8956f9167fa7 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 15 Sep 2025 13:31:13 -0700 Subject: [PATCH 3/7] Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- 
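Notes: the manual release that the TODO below refers to relies on a general PyTorch property: rebinding a tensor's .data attribute to an empty tensor detaches the tensor object from its (possibly large) storage, so the caching allocator can reclaim that memory even while other code still holds a reference to the tensor object. A minimal standalone sketch of that pattern, with made-up names (not the module code):

    import torch

    def release_storage(t: torch.Tensor) -> None:
        # Detach the tensor object from its underlying storage.
        t.data = torch.empty(0, dtype=t.dtype, device=t.device)

    old = torch.zeros(1024, 1024)   # ~4 MiB buffer
    holder = [old]                  # lingering reference, e.g. from autograd or a cache
    release_storage(old)            # the 4 MiB storage becomes reclaimable here...
    del old                         # ...even though holder still holds the (now empty) tensor
    assert holder[0].numel() == 0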
.../pytorch/tensor/_internal/float8_blockwise_tensor_base.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 2c7a54011b..da0220eb7a 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -349,6 +349,8 @@ def _create_columnwise(self): def _transpose_columnwise_data(self): """Plainly transpose the columnwise data and scale inv.""" if self._columnwise_data is not None: + # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. _old_data = self._columnwise_data self._columnwise_data = tex.fp8_transpose( self._columnwise_data, self._fp8_dtype, out=None From 9ff7a559dc802e2e400a140ab44ea46340f071be Mon Sep 17 00:00:00 2001 From: Yuzhong Wang Date: Mon, 15 Sep 2025 20:15:17 -0700 Subject: [PATCH 4/7] quick fix the errors when using UB buffers Signed-off-by: Yuzhong Wang --- .../pytorch/module/layernorm_linear.py | 12 ++++++++---- transformer_engine/pytorch/module/linear.py | 9 +++++++-- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1edea68ae0..3696e47b02 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -355,9 +355,12 @@ def forward( if not weight.requires_grad and not return_layernorm_output: clear_tensor_data(ln_out, ln_out_total) ln_out = ln_out_total = None - elif ln_out_total is not ln_out_return and ln_out_total is not ln_out: - clear_tensor_data(ln_out_total) - ln_out_total = None + else: + _ln_out_all_gather_nccl = with_input_all_gather and not ub_overlap_ag_fprop + if _ln_out_all_gather_nccl and not return_layernorm_output_gathered: + # ln_out_total is gathered by NCCL and not needed to be returned + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ -901,7 +904,8 @@ def wgrad_gemm( elif ctx.return_layernorm_output_gathered and ctx.ln_out_needs_gather: # ln_out is not the returned tensor clear_tensor_data(ln_out) - if ctx.ln_out_needs_gather: + if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: + # ln_out_total is gathered by NCCL clear_tensor_data(ln_out_total) # Update grad input if overlapping reduce-scatter with wgrad GEMM diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f138172f47..f1ce2dd4e0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -886,10 +886,15 @@ def wgrad_gemm( # Deallocate input tensor if permitted if ctx.owns_input: clear_tensor_data(inputmat_total) - elif ctx.backward_input_needs_gather: + elif ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: + # inputmat_total is all gathered by NCCL clear_tensor_data(inputmat_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: + _grad_output_all_gather_nccl = ( + (ctx.parallel_mode == "row" and ctx.sequence_parallel) # grad output is all gathered + and not ctx.ub_overlap_ag # all gathered by NCCL + ) + if _grad_output_all_gather_nccl: clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM From
b3b6a0c678e60c7921030800eac50ca94ab84f39 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Sep 2025 03:15:58 +0000 Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/linear.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f1ce2dd4e0..07a03a1de0 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -891,8 +891,10 @@ def wgrad_gemm( clear_tensor_data(inputmat_total) _grad_output_all_gather_nccl = ( - (ctx.parallel_mode == "row" and ctx.sequence_parallel) # grad output is all gathered - and not ctx.ub_overlap_ag # all gathered by NCCL + ( + ctx.parallel_mode == "row" and ctx.sequence_parallel + ) # grad output is all gathered + and not ctx.ub_overlap_ag # all gathered by NCCL ) if _grad_output_all_gather_nccl: clear_tensor_data(grad_output) From 70c84bc04e4cf5feef43a1125e19dee3e2010b1a Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 16 Sep 2025 11:38:55 -0700 Subject: [PATCH 6/7] Update transformer_engine/pytorch/module/linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/module/linear.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 07a03a1de0..6cbf394824 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -318,6 +318,8 @@ def forward( # ------------------------------------------------------ # Deallocate GEMM input tensor if no longer needed + # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. 
if with_input_all_gather_nccl: clear_tensor_data(inputmat_total) inputmat_total = None From 8f57b1283a2ebdc37dcc0b303f88560474717692 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 17 Sep 2025 02:55:01 +0000 Subject: [PATCH 7/7] Avoid deallocating FP8 scale-invs since they are reused Signed-off-by: Tim Moon --- .../pytorch/module/layernorm_linear.py | 28 ++++++++++--------- transformer_engine/pytorch/module/linear.py | 17 ++++------- .../tensor/_internal/float8_tensor_base.py | 9 ++++-- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3696e47b02..eb1a603646 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -355,12 +355,9 @@ def forward( if not weight.requires_grad and not return_layernorm_output: clear_tensor_data(ln_out, ln_out_total) ln_out = ln_out_total = None - else: - _ln_out_all_gather_nccl = with_input_all_gather and not ub_overlap_ag_fprop - if _ln_out_all_gather_nccl and not return_layernorm_output_gathered: - # ln_out_total is gathered by NCCL and not needed to be returned - clear_tensor_data(ln_out_total) - ln_out_total = None + elif with_input_all_gather and not return_layernorm_output_gathered: + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ -897,16 +894,19 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted + # Deallocate input tensors if permitted if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: - # Do not need to return layernorm output + # Input tensors have not been exposed externally clear_tensor_data(ln_out) - elif ctx.return_layernorm_output_gathered and ctx.ln_out_needs_gather: - # ln_out is not the returned tensor + elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: + # Non-gathered input has not been exposed externally clear_tensor_data(ln_out) - if ctx.ln_out_needs_gather and not ctx.ub_bulk_dgrad: - # ln_out_total is gathered by NCCL + if ctx.ln_out_needs_gather: + # Gathered input is internal clear_tensor_data(ln_out_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1182,7 +1182,9 @@ def __init__( self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output - self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.return_layernorm_output_gathered = ( + return_layernorm_output_gathered if return_layernorm_output else False + ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 6cbf394824..838272b94b 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -885,20 +885,15 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted + # Deallocate tensors if permitted if ctx.owns_input: + # Input tensor is internal clear_tensor_data(inputmat_total) - elif ctx.backward_input_needs_gather and not ctx.ub_bulk_dgrad: - # inputmat_total is all gathered by NCCL 
+ elif ctx.backward_input_needs_gather: + # Gathered input tensor is internal clear_tensor_data(inputmat_total) - - _grad_output_all_gather_nccl = ( - ( - ctx.parallel_mode == "row" and ctx.sequence_parallel - ) # grad output is all gathered - and not ctx.ub_overlap_ag # all gathered by NCCL - ) - if _grad_output_all_gather_nccl: + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 61edc999ac..6d48223443 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -95,8 +95,13 @@ def __new__( return instance def clear(self): - """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" - for t in (self._data, self._transpose, self._scale_inv): + """Deallocate this tensor's memory. Typically not needed and must be used carefully. + + Scale-inv tensor is not deallocated because it's often shared + between multiple FP8 tensors. + + """ + for t in (self._data, self._transpose): if t is not None: t.data = _empty_tensor() self._transpose_invalid = True
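Taken together, these patches apply one pattern in several places: a tensor that is only an intermediate for the local GEMM (the all-gathered sequence-parallel input, or the gathered grad output) is released explicitly right after the GEMM instead of waiting for garbage collection, while tensors that are exposed to the caller or reused elsewhere (a returned layernorm output, FP8 scale-invs) are left alone. Below is a minimal standalone sketch of the forward-pass case; the function and helper names are invented for illustration, and in the real modules the gathered input is re-created in backward when it is still needed for the weight gradient.

    import torch
    import torch.distributed as dist

    def clear_tensor_data(*tensors):
        # Drop each tensor's storage so the caching allocator can reuse it.
        for t in tensors:
            if t is not None:
                t.data = torch.empty(0, dtype=t.dtype, device=t.device)

    def sp_forward_gemm(input_shard, weight, tp_group):
        # All-gather the sequence-parallel shards into the full GEMM input.
        world = dist.get_world_size(tp_group)
        input_total = torch.empty(
            world * input_shard.shape[0], input_shard.shape[1],
            dtype=input_shard.dtype, device=input_shard.device,
        )
        dist.all_gather_into_tensor(input_total, input_shard, group=tp_group)

        out = input_total @ weight.t()  # forward GEMM on the gathered input

        # The gathered copy is internal to this step: free it immediately
        # rather than letting it stay alive until the autograd graph is freed.
        clear_tensor_data(input_total)
        input_total = None
        return out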