Move HPU broadcast override to the HPU strategy file (#17011)
(cherry picked from commit da6263a)
carmocca authored and lantiga committed Apr 3, 2023
1 parent 94af83b commit 10774ca
Showing 2 changed files with 78 additions and 172 deletions.
170 changes: 0 additions & 170 deletions src/pytorch_lightning/overrides/torch_distributed.py

This file was deleted.

80 changes: 78 additions & 2 deletions src/pytorch_lightning/strategies/hpu_parallel.py
@@ -24,7 +24,6 @@
 from lightning_fabric.utilities.distributed import group as _group
 from pytorch_lightning.accelerators.hpu import _HPU_AVAILABLE
 from pytorch_lightning.overrides import LightningDistributedModule
-from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -136,7 +135,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:  # type: ignore
         if self.global_rank != src:
             obj = [None]

-        broadcast_object_list(obj, src, group=_group.WORLD)
+        _hpu_broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]

     def on_after_backward(self) -> None:
@@ -179,3 +178,80 @@ def teardown(self) -> None:
         # Was set to local rank
         os.environ.pop("ID", None)
         os.environ.pop("HCCL_DISTRIBUTED_BACKEND", None)
+
+
+# The code underneath is taken from PyTorch `torch/distributed/distributed_c10d.py`;
+# the distributed backend and tensor type updates for the Habana backend are done here before the broadcast.
+def _hpu_broadcast_object_list(object_list, src=0, group=None, device=None):  # type: ignore
+    from torch.distributed import _rank_not_in_group, Backend, broadcast, get_backend, get_rank
+    from torch.distributed.distributed_c10d import _object_to_tensor, _tensor_to_object
+
+    if _rank_not_in_group(group):
+        return
+
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*[_object_to_tensor(obj, device) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` defaults to ``None``,
+    # in which case we run the current device-selection logic, i.e.
+    # ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU device. In
+    # case it is not ``None``, we move the size and object tensors to be
+    # broadcast to this device.
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == Backend.NCCL
+    is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
+    if device is not None:
+        if is_nccl_backend and device.type != "cuda":
+            raise ValueError("device type must be cuda for nccl backend")
+        current_device = device
+    else:
+        current_device = torch.device("cpu")
+        if is_nccl_backend:
+            # See note about using torch.cuda.current_device() here in
+            # docstring. We cannot simply use my_rank since rank == device is
+            # not necessarily true.
+            current_device = torch.device("cuda", torch.cuda.current_device())
+    if is_nccl_backend:
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    elif is_hpu_backend:
+        current_device = torch.device("hpu")
+        # Workaround: HPU does not support long tensors for collectives
+        if (object_sizes_tensor.type() == "torch.LongTensor") or (object_sizes_tensor.type() == "torch.hpu.LongTensor"):
+            object_sizes_tensor = object_sizes_tensor.int()
+        else:
+            print("unhandled hpu object_sizes_tensor type :: ", object_sizes_tensor.type())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    broadcast(object_sizes_tensor, src=src, group=group)
+
+    # Concatenate and broadcast serialized object tensors
+    if my_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(
+            torch.sum(object_sizes_tensor).int().item(),
+            dtype=torch.uint8,
+        )
+
+    if is_nccl_backend or is_hpu_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    broadcast(object_tensor, src=src, group=group)
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset : offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device("cpu"):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
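
For readers unfamiliar with `broadcast_object_list`-style collectives: the added function follows a serialize -> broadcast sizes -> broadcast payload -> deserialize pattern. The sketch below is a minimal, single-process illustration of that round trip (no distributed backend involved), not part of the commit. The helpers `_to_byte_tensor` and `_from_byte_tensor` are hypothetical stand-ins for torch's private `_object_to_tensor` / `_tensor_to_object`, and the int32 downcast mirrors the HPU long-tensor workaround in the diff.

import pickle

import torch


def _to_byte_tensor(obj):
    # Stand-in for `_object_to_tensor`: pickle the object into a uint8 tensor and record its length.
    payload = torch.tensor(list(pickle.dumps(obj)), dtype=torch.uint8)
    return payload, torch.tensor([payload.numel()], dtype=torch.long)


def _from_byte_tensor(tensor, size):
    # Stand-in for `_tensor_to_object`: unpickle the first `size` bytes of a uint8 tensor.
    return pickle.loads(bytes(tensor[:size].tolist()))


objects = [{"epoch": 3}, "best.ckpt"]

# What the src rank does: serialize every object and concatenate the payloads.
tensors, sizes = zip(*[_to_byte_tensor(obj) for obj in objects])
object_sizes_tensor = torch.cat(sizes)
object_tensor = torch.cat(tensors)

# HPU workaround from the diff: the size tensor is downcast from int64 to int32
# before the collective, since the HCCL backend rejects long tensors.
object_sizes_tensor = object_sizes_tensor.int()

# What a non-src rank does after both broadcasts: slice the payload apart
# using the received sizes and unpickle each piece.
offset, received = 0, []
for size in object_sizes_tensor.tolist():
    received.append(_from_byte_tensor(object_tensor[offset : offset + size], size))
    offset += size

assert received == objects

In a real HPU job every rank would reach this code through the strategy's public API, e.g. `trainer.strategy.broadcast(checkpoint_path, src=0)`, which wraps the object in a single-element list, calls `_hpu_broadcast_object_list`, and returns element 0 on every rank, as the second hunk above shows.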
