Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/layernorm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,9 @@ def backward(
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,10 @@ def backward(
mu,
rsigma,
) = restore_from_saved(ctx.tensor_objects, saved_tensors)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
fc1_weight_main_grad = (
ctx.fc1_main_grad
Expand Down
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/module/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking
restore_from_saved(ctx.tensor_objects, saved_tensors)
)
# Delete the references to tensor objects once they've been consumed
# by the `restore_from_saved` method to construct back the actual tensors.
ctx.tensor_objects = None

# Since main_grad can be modified inplace, it should not be a part of saved_tensors
main_grad = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor

"""
tensors = [self._data, self._transpose]
# self._data = None
# self._transpose = None
self._data = None
self._transpose = None
Comment thread
ptrendx marked this conversation as resolved.
return tensors, self

def restore_from_saved(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB

"""
tensors = [self._rowwise_data, self._columnwise_data]
# self._rowwise_data = None
# self._columnwise_data = None
self._rowwise_data = None
self._columnwise_data = None
return tensors, self

def restore_from_saved(
Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/tensor/float8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,15 @@ def clear(self):
self._transpose = torch.Tensor() if self._transpose is not None else None
self._transpose_invalid = True

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
"""Prepare the tensor base for saving for backward

After calling this, the tensor instance does not hold any
data.

"""
return [self], None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,15 @@ def clear(self):
self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None

def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
"""Prepare the tensor base for saving for backward

After calling this, the tensor instance does not hold any
data.

"""
return [self], None

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs=None):

Expand Down