#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with distributed training."""
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.nn.parallel.distributed import DistributedDataParallel
import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _HPU_AVAILABLE, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_debug as new_rank_zero_debug
from pytorch_lightning.utilities.rank_zero import rank_zero_only # noqa: F401
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.rank_zero import rank_zero_info as new_rank_zero_info
from pytorch_lightning.utilities.rank_zero import rank_zero_warn as new_rank_zero_warn
if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
    from torch.distributed import group, ReduceOp
else:

    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
        SUM = None

    class group:  # type: ignore
        WORLD = None


log = logging.getLogger(__name__)


def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
    """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes.

    Args:
        result: the value to sync
        group: the process group to gather results from. Defaults to all processes (world)

    Return:
        gathered_result: list with size equal to the process group where
            gathered_result[i] corresponds to result tensor from process i
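
    Example:
        A minimal sketch; assumes ``torch.distributed`` is already initialized with
        two processes and that ``result`` lives on the correct device for the backend:

        >>> result = torch.tensor([torch.distributed.get_rank()])  # doctest: +SKIP
        >>> gather_all_tensors(result)  # doctest: +SKIP
        [tensor([0]), tensor([1])]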
"""
if group is None:
group = torch.distributed.group.WORLD
# convert tensors to contiguous format
result = result.contiguous()
world_size = torch.distributed.get_world_size(group)
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
# sync and broadcast all
torch.distributed.barrier(group=group)
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result


def distributed_available() -> bool:
    return torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()


def sync_ddp_if_available(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce a tensor across worker processes during distributed training.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be the string 'avg' or 'mean' to calculate the mean during reduction.

    Return:
        reduced value
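
    Example:
        A minimal sketch; assumes the default process group has already been
        initialized across the participating processes:

        >>> value = torch.tensor(1.0)  # doctest: +SKIP
        >>> summed = sync_ddp_if_available(value, reduce_op="sum")  # doctest: +SKIP
        >>> averaged = sync_ddp_if_available(value, reduce_op="mean")  # doctest: +SKIP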
"""
if distributed_available():
return sync_ddp(result, group=group, reduce_op=reduce_op)
return result


def sync_ddp(
    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
    """Function to reduce the tensors from several ddp processes to one main process.

    Args:
        result: the value to sync and reduce (typically tensor or number)
        group: the process group to gather results from. Defaults to all processes (world)
        reduce_op: the reduction operation. Defaults to sum.
            Can also be the string 'avg' or 'mean' to calculate the mean during reduction.

    Return:
        reduced value
    """
    divide_by_world_size = False

    if group is None:
        group = torch.distributed.group.WORLD

    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op

    # Workaround for HPU: long tensors are unsupported, so cast them to float.
    if _HPU_AVAILABLE:
        is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
        if is_hpu_backend:
            if result.type() in ("torch.LongTensor", "torch.hpu.LongTensor"):
                new_rank_zero_info("Long tensor unsupported on HPU, casting to float")
                result = result.float()

    # sync all processes before reduction
    torch.distributed.barrier(group=group)
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

    if divide_by_world_size:
        result = result / torch.distributed.get_world_size(group)

    return result


class AllGatherGrad(torch.autograd.Function):
    """``all_gather`` with gradient support: gradients are all-reduced in the backward pass."""

    @staticmethod
    def forward(
        ctx: Any,
        tensor: torch.Tensor,
        group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
    ) -> torch.Tensor:
        ctx.group = group

        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]

        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
        gathered_tensor = torch.stack(gathered_tensor, dim=0)

        return gathered_tensor

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        grad_output = torch.cat(grad_output)

        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)

        return grad_output[torch.distributed.get_rank()], None


def all_gather_ddp_if_available(
    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> torch.Tensor:
    """Function to gather a tensor from several distributed processes.

    Args:
        tensor: tensor of shape (batch, ...)
        group: the process group to gather results from. Defaults to all processes (world)
        sync_grads: flag that allows users to synchronize gradients for the all_gather op

    Return:
        A tensor of shape (world_size, batch, ...)
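
    Example:
        A minimal sketch; assumes distributed training is already initialized with two
        processes and that every process holds a tensor of the same shape:

        >>> local = torch.randn(8, 3, requires_grad=True)  # doctest: +SKIP
        >>> gathered = all_gather_ddp_if_available(local, sync_grads=True)  # doctest: +SKIP
        >>> gathered.shape  # doctest: +SKIP
        torch.Size([2, 8, 3])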
"""
group = group if group is not None else torch.distributed.group.WORLD
if distributed_available():
if sync_grads:
return AllGatherGrad.apply(tensor, group)
with torch.no_grad():
return AllGatherGrad.apply(tensor, group)
return tensor


def register_ddp_comm_hook(
    model: DistributedDataParallel,
    ddp_comm_state: Optional[object] = None,
    ddp_comm_hook: Optional[Callable] = None,
    ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
    """Function to register communication hook for DDP model https://pytorch.org/docs/master/ddp_comm_hooks.html.

    Args:
        model:
            DDP model
        ddp_comm_state:
            state is passed to the hook and can be used to maintain
            and update any state information that users would like to
            maintain as part of the training process. Examples: error
            feedback in gradient compression, peers to communicate with
            next in GossipGrad, etc.
        ddp_comm_hook:
            hook(state: object, bucket: dist._GradBucket) -> torch.futures.Future

            This callable function is called once the bucket is ready. The
            hook can perform whatever processing is needed and return
            a Future indicating completion of any async work (ex: allreduce).
            If the hook doesn't perform any communication, it can also
            just return a completed Future. The Future should hold the
            new value of grad bucket's tensors. Once a bucket is ready,
            c10d reducer would call this hook and use the tensors returned
            by the Future and copy grads to individual parameters.
        ddp_comm_wrapper:
            communication hook wrapper to support a communication hook such
            as FP16 compression as wrapper, which could be combined with
            ddp_comm_hook

    .. warning::
        The DDP communication hook requires PyTorch 1.8.0 or later.

    .. warning::
        The DDP communication wrapper requires PyTorch 1.9.0 or later.
        The post-localSGD hook requires PyTorch 1.9.0 or later.

    Examples:

        >>> from torch.distributed.algorithms.ddp_comm_hooks import ( # doctest: +SKIP
        ...     default_hooks as default,
        ...     powerSGD_hook as powerSGD,
        ...     post_localSGD_hook as post_localSGD,
        ... )
        >>>
        >>> # fp16_compress_hook for compressing gradients
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_hook=default.fp16_compress_hook,
        ... )
        >>>
        >>> # powerSGD_hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ... )
        >>>
        >>> # post_localSGD_hook
        >>> subgroup, _ = torch.distributed.new_subgroups() # doctest: +SKIP
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=post_localSGD.PostLocalSGDState(
        ...         process_group=None,
        ...         subgroup=subgroup,
        ...         start_localSGD_iter=1_000,
        ...     ),
        ...     ddp_comm_hook=post_localSGD.post_localSGD_hook,
        ... )
        >>>
        >>> # fp16_compress_wrapper combined with other communication hook
        >>> ddp_model = ...
        >>> register_ddp_comm_hook( # doctest: +SKIP
        ...     model=ddp_model,
        ...     ddp_comm_state=powerSGD.PowerSGDState(
        ...         process_group=None,
        ...         matrix_approximation_rank=1,
        ...         start_powerSGD_iter=5000,
        ...     ),
        ...     ddp_comm_hook=powerSGD.powerSGD_hook,
        ...     ddp_comm_wrapper=default.fp16_compress_wrapper,
        ... )
    """
    if ddp_comm_hook is None:
        return

    # inform mypy that ddp_comm_hook is callable
    ddp_comm_hook: Callable = ddp_comm_hook

    if ddp_comm_wrapper is not None:
        if not _TORCH_GREATER_EQUAL_1_9:
            new_rank_zero_warn(
                "Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0."
            )
        else:
            new_rank_zero_info(
                f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
            )
            ddp_comm_hook = ddp_comm_wrapper(ddp_comm_hook)

    new_rank_zero_debug(f"Registering DDP comm hook: {ddp_comm_hook.__qualname__}.")
    model.register_comm_hook(state=ddp_comm_state, hook=ddp_comm_hook)


def tpu_distributed() -> bool:
    return _TPU_AVAILABLE and xm.xrt_world_size() > 1


def get_default_process_group_backend_for_device(device: torch.device) -> str:
    return "nccl" if device.type == "cuda" else "gloo"


def _get_process_group_backend_from_env() -> Optional[str]:
    torch_backend = os.getenv("PL_TORCH_DISTRIBUTED_BACKEND")
    if torch_backend is not None:
        rank_zero_deprecation(
            "Environment variable `PL_TORCH_DISTRIBUTED_BACKEND`"
            " was deprecated in v1.6 and will be removed in v1.8."
            " Specify `process_group_backend` directly on the strategy constructor."
        )
    return torch_backend


def init_dist_connection(
    cluster_environment: "pl.plugins.environments.ClusterEnvironment",
    torch_distributed_backend: str,
    global_rank: Optional[int] = None,
    world_size: Optional[int] = None,
    **kwargs: Any,
) -> None:
    """Utility function to initialize distributed connection by setting env variables and initializing the
    distributed process group.

    Args:
        cluster_environment: ``ClusterEnvironment`` instance
        torch_distributed_backend: backend to use (includes `nccl` and `gloo`)
        global_rank: rank of the current process
        world_size: number of processes in the group
        kwargs: kwargs for ``init_process_group``

    Raises:
        RuntimeError:
            If ``torch.distributed`` is not available
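
    Example:
        A minimal single-process sketch; assumes a ``LightningEnvironment`` (which
        provides the main address and port) and the ``gloo`` backend:

        >>> from pytorch_lightning.plugins.environments import LightningEnvironment  # doctest: +SKIP
        >>> init_dist_connection(LightningEnvironment(), "gloo", global_rank=0, world_size=1)  # doctest: +SKIP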
"""
if not torch.distributed.is_available():
raise RuntimeError("torch.distributed is not available. Cannot initialize distributed process group")
if torch.distributed.is_initialized():
log.debug("torch.distributed is already initialized. Exiting early")
return
global_rank = global_rank if global_rank is not None else cluster_environment.global_rank()
world_size = world_size if world_size is not None else cluster_environment.world_size()
os.environ["MASTER_ADDR"] = cluster_environment.main_address
os.environ["MASTER_PORT"] = str(cluster_environment.main_port)
log.info(f"Initializing distributed: GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
torch.distributed.init_process_group(torch_distributed_backend, rank=global_rank, world_size=world_size, **kwargs)
# on rank=0 let everyone know training is starting
new_rank_zero_info(
f"{'-' * 100}\n"
f"distributed_backend={torch_distributed_backend}\n"
f"All distributed processes registered. Starting with {world_size} processes\n"
f"{'-' * 100}\n"
)


def _broadcast_object_list(obj: Any, rank: int) -> Any:
    objects = [obj if torch.distributed.get_rank() == rank else None]
    torch.distributed.broadcast_object_list(objects, src=rank)
    return objects[0]


# TODO: Refactor with the Strategy Collectives once finalized.
def _collect_states_on_rank_zero(state: Dict[str, Any]) -> Dict[int, Any]:
    """This distributed utility collects dictionary state across all processes.

    Args:
        state: Dictionary containing the state of the current process

    Returns:
        states: A dictionary where the keys are the process ranks and the values their
            associated states. When distributed training is unavailable, only the local
            state is returned, under rank 0.
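
    Example:
        A minimal sketch; assumes two initialized processes, each passing its own
        state dictionary:

        >>> states = _collect_states_on_rank_zero({"rank": torch.distributed.get_rank()})  # doctest: +SKIP
        >>> states  # doctest: +SKIP
        {0: {'rank': 0}, 1: {'rank': 1}}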
"""
if not distributed_available():
return {0: state}
return {rank: _broadcast_object_list(state, rank) for rank in range(torch.distributed.get_world_size())}


def rank_zero_info(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_info has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_info(*args, **kwargs)


def rank_zero_debug(*args: Any, **kwargs: Any) -> Any:
    rank_zero_deprecation(
        "pytorch_lightning.utilities.distributed.rank_zero_debug has been deprecated in v1.6"
        " and will be removed in v1.8."
        " Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead."
    )
    return new_rank_zero_debug(*args, **kwargs)