diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index cd02f31132..eb1a603646 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -353,8 +353,11 @@ def forward(
 
         # Deallocate GEMM input tensor if no longer needed
         if not weight.requires_grad and not return_layernorm_output:
-            ln_out = ln_out_total = None
             clear_tensor_data(ln_out, ln_out_total)
+            ln_out = ln_out_total = None
+        elif with_input_all_gather and not return_layernorm_output_gathered:
+            clear_tensor_data(ln_out_total)
+            ln_out_total = None
 
         # ------------------------------------------------------
         # Prepare output tensor
@@ -891,9 +894,19 @@ def wgrad_gemm(
                 grad_bias = grad_bias_
             del grad_bias_
 
-            # Deallocate input tensor if permitted
-            if not ctx.return_layernorm_output:
+            # Deallocate input tensors if permitted
+            if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered:
+                # Input tensors have not been exposed externally
+                clear_tensor_data(ln_out)
+            elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered:
+                # Non-gathered input has not been exposed externally
+                clear_tensor_data(ln_out)
+            if ctx.ln_out_needs_gather:
+                # Gathered input is internal
                 clear_tensor_data(ln_out_total)
+            if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+                # Gathered grad output tensor is internal
+                clear_tensor_data(grad_output)
 
             # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
@@ -1169,7 +1182,9 @@ def __init__(
         self.return_bias = return_bias
         self.apply_bias = self.use_bias and not return_bias
         self.return_layernorm_output = return_layernorm_output
-        self.return_layernorm_output_gathered = return_layernorm_output_gathered
+        self.return_layernorm_output_gathered = (
+            return_layernorm_output_gathered if return_layernorm_output else False
+        )
         self.zero_centered_gamma = zero_centered_gamma
         self.symmetric_ar_type = symmetric_ar_type
 
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 2ce6fb4c1d..838272b94b 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -317,6 +317,13 @@ def forward(
         # Finished forward GEMM...
         # ------------------------------------------------------
 
+        # Deallocate GEMM input tensor if no longer needed
+        # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically
+        # deallocated by GC. Manually deallocating is a temporary hack.
+        if with_input_all_gather_nccl:
+            clear_tensor_data(inputmat_total)
+            inputmat_total = None
+
         # ------------------------------------------------------
         # Prepare output tensor
         # Note: Perform tensor-parallel communication
@@ -878,9 +885,16 @@ def wgrad_gemm(
                 grad_bias = grad_bias_
             del grad_bias_
 
-            # Deallocate input tensor if permitted
+            # Deallocate tensors if permitted
             if ctx.owns_input:
+                # Input tensor is internal
+                clear_tensor_data(inputmat_total)
+            elif ctx.backward_input_needs_gather:
+                # Gathered input tensor is internal
                 clear_tensor_data(inputmat_total)
+            if ctx.parallel_mode == "row" and ctx.sequence_parallel:
+                # Gathered grad output tensor is internal
+                clear_tensor_data(grad_output)
 
             # Update grad input if overlapping reduce-scatter with wgrad GEMM
             if ctx.ub_bulk_wgrad:
diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
index adffe7c580..da0220eb7a 100644
--- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
+++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
@@ -349,9 +349,14 @@ def _create_columnwise(self):
     def _transpose_columnwise_data(self):
         """Plainly transpose the columnwise data and scale inv."""
         if self._columnwise_data is not None:
+            # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
+            # deallocated by GC. Manually deallocating is a temporary hack.
+            _old_data = self._columnwise_data
             self._columnwise_data = tex.fp8_transpose(
                 self._columnwise_data, self._fp8_dtype, out=None
             )
+            _old_data.data = _empty_tensor()
+            del _old_data
 
     def __repr__(self):
         if self._rowwise_data is not None:
diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
index 61edc999ac..6d48223443 100644
--- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
+++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py
@@ -95,8 +95,13 @@ def __new__(
         return instance
 
     def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        for t in (self._data, self._transpose, self._scale_inv):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully.
+
+        Scale-inv tensor is not deallocated because it's often shared
+        between multiple FP8 tensors.
+
+        """
+        for t in (self._data, self._transpose):
             if t is not None:
                 t.data = _empty_tensor()
         self._transpose_invalid = True
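
For context, every deallocation site in this patch relies on the same PyTorch behavior: rebinding a tensor's .data to an empty tensor (the `t.data = _empty_tensor()` lines above) drops the reference to its backing storage, so the caching allocator can reclaim the memory even while other Python references to the tensor object stay alive. Below is a minimal standalone sketch of that behavior; the helper name release_storage is illustrative only and is not part of Transformer Engine's API.

import torch

def release_storage(t: torch.Tensor) -> None:
    """Rebind the tensor's .data to an empty tensor so the original
    backing storage can be freed, the same pattern as the
    t.data = _empty_tensor() lines in the patch."""
    t.data = torch.empty(0, dtype=t.dtype, device=t.device)

device = "cuda" if torch.cuda.is_available() else "cpu"
buf = torch.randn(1024, 1024, device=device)
alias = buf                 # extra reference to the same tensor object
release_storage(buf)
assert alias.numel() == 0   # both names now see the empty storage

This is also why clear() now skips _scale_inv: because a scale-inv tensor is often shared between multiple FP8 tensors, emptying its storage through one owner would empty it for every tensor that shares it.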