From fd76d53ede43a5b92f6fe70ff839ea4f12413ef6 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Fri, 21 Feb 2025 16:27:56 -0800 Subject: [PATCH 1/3] Fix a crash with module._apply(lambda t: t.cpu()) Signed-off-by: Guyue Huang --- transformer_engine/pytorch/tensor/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index da788182a0..cb36a103dc 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -481,7 +481,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: casts to FP8. """ - + torch.cuda.synchronize() # Tensor device new_device = tensor.device if tensor.is_cuda else self.device From f6ccf8099da8e7aec4dc0498358f2ddeff92fd40 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Sat, 22 Feb 2025 12:04:50 -0800 Subject: [PATCH 2/3] Add comments Signed-off-by: Guyue Huang --- transformer_engine/pytorch/tensor/float8_tensor.py | 3 +++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index cb36a103dc..7150f9e469 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -481,7 +481,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: casts to FP8. """ + + # Synchronize to avoid a race condition torch.cuda.synchronize() + # Tensor device new_device = tensor.device if tensor.is_cuda else self.device diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 86b13415a1..3ee9d6d509 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -366,6 +366,9 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ + # Synchronize to avoid a race condition + torch.cuda.synchronize() + # Tensor device new_device = tensor.device if tensor.is_cuda else self.device From 2e35ba6f139571212aec6475ac68de8ceda66676 Mon Sep 17 00:00:00 2001 From: Guyue Huang Date: Mon, 24 Feb 2025 13:31:54 -0800 Subject: [PATCH 3/3] Make sure tensor is moved to dst device before quantizer quantizes Signed-off-by: Guyue Huang --- transformer_engine/pytorch/tensor/float8_tensor.py | 5 ++--- transformer_engine/pytorch/tensor/mxfp8_tensor.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 7150f9e469..989959817a 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -482,11 +482,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ - # Synchronize to avoid a race condition - torch.cuda.synchronize() - # Tensor device new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) # Just copy FP8 data if other tensor is Float8Tensor if isinstance(tensor, Float8Tensor): diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 3ee9d6d509..6e3835fbef 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -366,11 +366,10 @@ def _set_data(self, tensor: torch.Tensor) -> None: """ - # Synchronize to avoid a race condition - torch.cuda.synchronize() - # Tensor device new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) # Just copy FP8 data if other tensor is MXFP8Tensor if isinstance(tensor, MXFP8Tensor):