Merged
23 changes: 19 additions & 4 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -353,8 +353,11 @@ def forward(

         # Deallocate GEMM input tensor if no longer needed
         if not weight.requires_grad and not return_layernorm_output:
-            ln_out = ln_out_total = None
+            clear_tensor_data(ln_out, ln_out_total)
+            ln_out = ln_out_total = None
+        elif with_input_all_gather and not return_layernorm_output_gathered:
+            clear_tensor_data(ln_out_total)
+            ln_out_total = None

         # ------------------------------------------------------
         # Prepare output tensor

@@ -891,9 +894,19 @@ def wgrad_gemm(
         grad_bias = grad_bias_
         del grad_bias_

-        # Deallocate input tensor if permitted
-        if not ctx.return_layernorm_output:
+        # Deallocate input tensors if permitted
+        if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
+            # Input tensors have not been exposed externally
             clear_tensor_data(ln_out)
+        elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered:
+            # Non-gathered input has not been exposed externally
+            clear_tensor_data(ln_out)
+        if ctx.ln_out_needs_gather:
+            # Gathered input is internal
+            clear_tensor_data(ln_out_total)
+        if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+            # Gathered grad output tensor is internal
+            clear_tensor_data(grad_output)

         # Update grad input if overlapping reduce-scatter with wgrad GEMM
         if ctx.ub_bulk_wgrad:

@@ -1169,7 +1182,9 @@ def __init__(
         self.return_bias = return_bias
         self.apply_bias = self.use_bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
-        self.return_layernorm_output_gathered = return_layernorm_output_gathered
+        self.return_layernorm_output_gathered = (
+            return_layernorm_output_gathered if return_layernorm_output else False
+        )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type

16 changes: 15 additions & 1 deletion transformer_engine/pytorch/module/linear.py
@@ -317,6 +317,13 @@ def forward(
         # Finished forward GEMM...
         # ------------------------------------------------------

+        # Deallocate GEMM input tensor if no longer needed
+        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
+        # deallocated by GC. Manually deallocating is a temporary hack.
+        if with_input_all_gather_nccl:
+            clear_tensor_data(inputmat_total)
+            inputmat_total = None
+
         # ------------------------------------------------------
         # Prepare output tensor
         # Note: Perform tensor-parallel communication

@@ -878,9 +885,16 @@ def wgrad_gemm(
         grad_bias = grad_bias_
         del grad_bias_

-        # Deallocate input tensor if permitted
+        # Deallocate tensors if permitted
         if ctx.owns_input:
+            # Input tensor is internal
             clear_tensor_data(inputmat_total)
+        elif ctx.backward_input_needs_gather:
+            # Gathered input tensor is internal
+            clear_tensor_data(inputmat_total)
+        if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+            # Gathered grad output tensor is internal
+            clear_tensor_data(grad_output)

         # Update grad input if overlapping reduce-scatter with wgrad GEMM
         if ctx.ub_bulk_wgrad:
@@ -349,9 +349,14 @@ def _create_columnwise(self):
     def _transpose_columnwise_data(self):
         """Plainly transpose the columnwise data and scale inv."""
         if self._columnwise_data is not None:
+            # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
+            # deallocated by GC. Manually deallocating is a temporary hack.
+            _old_data = self._columnwise_data
             self._columnwise_data = tex.fp8_transpose(
                 self._columnwise_data, self._fp8_dtype, out=None
             )
+            _old_data.data = _empty_tensor()
+            del _old_data
Comment on lines +354 to +359
Collaborator
Shouldn't Python refcounting deallocate the old data automatically? If not, then there's a larger bug in tex.fp8_transpose that we should fix. The code will become unmanageable if we have to apply this kind of trick everywhere we call a C++ extension.

I see that we return at::Tensor instead of Pybind11 objects. I wonder if it is not properly handling the refcount.
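
For reference, a minimal sketch in plain PyTorch (outside TE, with illustrative shapes) of the refcounting behavior being assumed here: rebinding the only Python reference to a tensor should return its storage to the caching allocator immediately, without waiting for a GC cycle. The question in this thread is why the FP8 path above does not show this behavior.

```python
import torch

x = torch.empty(1024, 1024, device="cuda")   # illustrative buffer
torch.cuda.synchronize()
before = torch.cuda.memory_allocated()

# Rebinding `x` drops the last reference to the old tensor; its storage should be
# freed as soon as the new tensor is materialized.
x = x.t().contiguous()
torch.cuda.synchronize()
after = torch.cuda.memory_allocated()

print(before, after)   # expected: roughly equal, not doubled
```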

Contributor Author
@yuzhongw-nvidia Sep 9, 2025

It is a very good question that also confuses me. I have a similar one: why do we need clear_tensor_data to release unused tensors instead of letting the Python GC deallocate them? I suspect the root cause is similar, and my current fix is the same kind of trick as clear_tensor_data.

I'm not very familiar with the code or the deeper implementation. Could you please raise this with some TE experts to help solve it? Or do you think we could merge this PR first and track down the root cause later? This is fairly urgent for the runnability and performance of DSV3 / MLA long-context training.

Collaborator
@timmoon10 Sep 9, 2025

We call clear_tensor_data to manually deallocate memory before Python GC. A good example is here in the LayerNormLinear backward:


The forward GEMM input tensor is stored within the autograd ctx, so the GC will not deallocate it until after the backward has finished. However, we don't need this buffer after the wgrad GEMM, and ideally its memory would be reused for the LayerNorm grad.
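
To illustrate that lifetime, here is a minimal sketch with a toy autograd function and a simplified stand-in for clear_tensor_data (not the actual TE modules): the saved activation stays alive inside ctx for the whole backward, so releasing its storage right after the wgrad GEMM lets the allocator reuse that memory for the dgrad GEMM and the LayerNorm grad.

```python
import torch

def clear_tensor_data(*tensors):
    # Simplified stand-in: drop the underlying storage but keep the Python objects.
    for t in tensors:
        if t is not None:
            t.data = torch.Tensor()

class _ToyLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)   # inp now lives until backward returns
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        grad_weight = grad_output.t() @ inp  # wgrad GEMM: last use of inp
        clear_tensor_data(inp)               # release the buffer early, not after backward
        grad_input = grad_output @ weight    # dgrad GEMM can reuse the freed memory
        return grad_input, grad_weight
```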

Contributor Author
@yuzhongw-nvidia Sep 10, 2025

As I described in this thread, I think this is not an issue in the C++ extension but in Float8BlockQuantizer.make_empty (in this case the unreleased tensor is also created by make_empty; the C++ extension only creates the new tensor). In other words, I believe the issue is that tensors created by Float8BlockQuantizer.make_empty can only be released by manually clearing _columnwise_data.data and _rowwise_data.data. (I believe that is also why clear_tensor_data clears _columnwise_data.data and _rowwise_data.data rather than deleting _columnwise_data and _rowwise_data themselves.)
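
The distinction being drawn can be shown with plain PyTorch (names here are illustrative, not TE internals): deleting a name or attribute only drops one reference, whereas swapping the tensor's .data releases the storage even while the wrapper object is still referenced elsewhere, which is the trick clear_tensor_data and the workaround above rely on.

```python
import torch

buf = torch.empty(256, 256, device="cuda")
holder = buf                     # a second reference, e.g. stashed in an autograd ctx

del buf                          # removing one name frees nothing while `holder` is alive
assert holder.numel() == 256 * 256

holder.data = torch.Tensor()     # swapping .data returns the CUDA storage to the allocator,
                                 # even though the wrapper tensor object is still referenced
```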


     def __repr__(self):
         if self._rowwise_data is not None:
@@ -95,8 +95,13 @@ def __new__(
         return instance

     def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        for t in (self._data, self._transpose, self._scale_inv):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully.
+
+        Scale-inv tensor is not deallocated because it's often shared
+        between multiple FP8 tensors.
+
+        """
+        for t in (self._data, self._transpose):
             if t is not None:
                 t.data = _empty_tensor()
         self._transpose_invalid = True
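
A minimal sketch of why the scale-inv buffer is skipped (a toy container with the same attribute names; the real class and its quantization logic are not shown): several FP8 tensors can point at one shared scale-inv tensor, so clearing it through one of them would break the others.

```python
import torch

class _ToyFp8Tensor:
    # Illustrative container only; mirrors the attribute names in the diff above.
    def __init__(self, data, scale_inv):
        self._data = data
        self._transpose = None
        self._scale_inv = scale_inv        # possibly shared with sibling tensors
        self._transpose_invalid = True

    def clear(self):
        # Release only the buffers this tensor owns exclusively.
        for t in (self._data, self._transpose):
            if t is not None:
                t.data = torch.Tensor()
        self._transpose_invalid = True

shared_scale_inv = torch.ones(1)
a = _ToyFp8Tensor(torch.zeros(16, dtype=torch.uint8), shared_scale_inv)
b = _ToyFp8Tensor(torch.zeros(16, dtype=torch.uint8), shared_scale_inv)
a.clear()
assert b._scale_inv.numel() == 1           # b's scale-inv remains usable
```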