From 065661bf93bcb91890f74bca56ba12e1060593d9 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 20 Feb 2025 22:00:28 -0800 Subject: [PATCH 1/8] delete extra tensor objects after restoring float8 tensors Signed-off-by: Sudhakar Singh --- transformer_engine/pytorch/module/layernorm_linear.py | 3 +++ transformer_engine/pytorch/module/layernorm_mlp.py | 4 ++++ transformer_engine/pytorch/module/linear.py | 3 +++ .../pytorch/tensor/_internal/float8_tensor_base.py | 4 ++-- .../pytorch/tensor/_internal/mxfp8_tensor_base.py | 4 ++-- 5 files changed, 14 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 01bda64101..b7d6e21f2d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -448,6 +448,9 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restor_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 88eebc8e6c..7372609822 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -567,6 +567,10 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed + # by the `restor_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None + # Since main_grad can be modified inplace, it should not be a part of saved_tensors fc1_weight_main_grad = ( ctx.fc1_main_grad diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e51513630f..d6b4bcd25c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -352,6 +352,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat, weight_fp8, weight, bias = ( # pylint: disable=unbalanced-tuple-unpacking restore_from_saved(ctx.tensor_objects, saved_tensors) ) + # Delete the references to tensor objects once they've been consumed + # by the `restor_from_saved` method to construct back the actual tensors. + ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors main_grad = ( diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 6b816db3b5..8ae45c9375 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -105,8 +105,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor """ tensors = [self._data, self._transpose] - # self._data = None - # self._transpose = None + self._data = None + self._transpose = None return tensors, self def restore_from_saved( diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index d78bd55d9a..ea7fc3cf2f 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -100,8 +100,8 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorB """ tensors = [self._rowwise_data, self._columnwise_data] - # self._rowwise_data = None - # self._columnwise_data = None + self._rowwise_data = None + self._columnwise_data = None return tensors, self def restore_from_saved( From 01a34dbe0360b3e6721f44de61ae65477fc0f978 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Feb 2025 06:06:14 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7372609822..9c02b8da34 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -570,7 +570,7 @@ def backward( # Delete the references to tensor objects once they've been consumed # by the `restor_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None - + # Since main_grad can be modified inplace, it should not be a part of saved_tensors fc1_weight_main_grad = ( ctx.fc1_main_grad From a1120ea9cfe8739a7f2c00972e31ce25de897197 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Sat, 22 Feb 2025 17:56:34 -0800 Subject: [PATCH 3/8] nit fix Signed-off-by: Sudhakar Singh --- transformer_engine/pytorch/module/layernorm_linear.py | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index b7d6e21f2d..007821038f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -449,7 +449,7 @@ def backward( rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) # Delete the references to tensor objects once they've been consumed - # by the `restor_from_saved` method to construct back the actual tensors. + # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7372609822..eb04380ebf 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -568,7 +568,7 @@ def backward( rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) # Delete the references to tensor objects once they've been consumed - # by the `restor_from_saved` method to construct back the actual tensors. + # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d6b4bcd25c..3fe6144050 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -353,7 +353,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], restore_from_saved(ctx.tensor_objects, saved_tensors) ) # Delete the references to tensor objects once they've been consumed - # by the `restor_from_saved` method to construct back the actual tensors. + # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None # Since main_grad can be modified inplace, it should not be a part of saved_tensors From e36b440a86cb5408e0df1b22f003328dabf59cb2 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 26 Feb 2025 16:18:25 -0800 Subject: [PATCH 4/8] fix the leak in float8tensor and mxfloat8tensor classes Signed-off-by: Sudhakar Singh --- transformer_engine/pytorch/tensor/float8_tensor.py | 9 +++++++++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index da788182a0..fd3ab050cc 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -340,6 +340,15 @@ def clear(self): self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose_invalid = True + # def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + # """Prepare the tensor base for saving for backward + + # After calling this, the tensor instance does not hold any + # data. + + # """ + # return [self], None + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 86b13415a1..9dad2ae2ec 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -285,6 +285,15 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None + # def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + # """Prepare the tensor base for saving for backward + + # After calling this, the tensor instance does not hold any + # data. + + # """ + # return [self], None + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From 17bf57d65655fc34fa9cd897d63bff500ed7f173 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 00:18:49 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 9cc138ef96..37146cc2c6 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -356,7 +356,7 @@ def clear(self): # """ # return [self], None - + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 10ce6d51b6..ffd5883e9f 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -293,7 +293,7 @@ def clear(self): # """ # return [self], None - + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From b3643cda31ed43728445a342ed1a3569c257fc39 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 27 Feb 2025 14:24:03 -0800 Subject: [PATCH 6/8] uncomment the fix Signed-off-by: Sudhakar Singh --- transformer_engine/pytorch/tensor/float8_tensor.py | 12 ++++++------ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 12 ++++++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index fd3ab050cc..ca22c0885e 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -340,14 +340,14 @@ def clear(self): self._transpose = torch.Tensor() if self._transpose is not None else None self._transpose_invalid = True - # def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: - # """Prepare the tensor base for saving for backward + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward - # After calling this, the tensor instance does not hold any - # data. + After calling this, the tensor instance does not hold any + data. - # """ - # return [self], None + """ + return [self], None @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 9dad2ae2ec..f9454f9c6c 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -285,14 +285,14 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None - # def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: - # """Prepare the tensor base for saving for backward + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + """Prepare the tensor base for saving for backward - # After calling this, the tensor instance does not hold any - # data. + After calling this, the tensor instance does not hold any + data. - # """ - # return [self], None + """ + return [self], None @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From 0c10c5fe3bb3277ecad5e8e666f3ec3545132fb1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Feb 2025 22:25:30 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 2666470863..5434cfb2fc 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -356,7 +356,7 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor """ return [self], None - + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index afe908763f..075bcade92 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -293,7 +293,7 @@ def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8Tensor """ return [self], None - + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): From 0c3aa4354dc79bf2b28928fb89bbe7389eaee0b6 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Thu, 27 Feb 2025 14:30:33 -0800 Subject: [PATCH 8/8] fix lint Signed-off-by: Sudhakar Singh --- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index afe908763f..dc22b624bc 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -285,7 +285,7 @@ def clear(self): self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: """Prepare the tensor base for saving for backward After calling this, the tensor instance does not hold any