Enable TORCH_NCCL_AVOID_RECORD_STREAMS=1 by default (#512)
IvanYashchuk committed Jun 5, 2024
1 parent 4cc7b64 commit 23da3c1
Showing 4 changed files with 53 additions and 7 deletions.
3 changes: 3 additions & 0 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -27,6 +27,9 @@
local_rank = int(os.environ.get("LOCAL_RANK", 0))
global_rank = int(os.environ.get("RANK", 0))
if world_size > 1:
# Avoids the allocator thrashing issue in the PyTorch NCCL backend.
# See https://github.com/Lightning-AI/lightning-thunder/issues/420
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
torch_dist.init_process_group(backend="nccl")
pg = torch_dist.distributed_c10d._get_default_group()
device = torch.device("cuda", local_rank)
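For context, a minimal standalone sketch (not part of this commit) of the ordering the change relies on: TORCH_NCCL_AVOID_RECORD_STREAMS appears to be read when the NCCL process group is constructed, which is why both this file and the decorator added below export it before group creation. The rendezvous values below are placeholders for a single-process run on one CUDA device.

import os
import torch.distributed as dist

# Set the flag before the NCCL backend is constructed so it takes effect.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl", rank=0, world_size=1)
print("avoid record streams:", os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"])
dist.destroy_process_group()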
51 changes: 46 additions & 5 deletions thunder/distributed/__init__.py
@@ -38,6 +38,45 @@
_skip_data_parallel_grad_sync = ContextVar("skip_data_parallel_grad_sync", default=False)


def _avoid_torch_nccl_record_streams(func):
"""
Avoids the allocator thrashing issue in PyTorch NCCL backend.
"""

env_var = "TORCH_NCCL_AVOID_RECORD_STREAMS"
value = os.environ.get(env_var, "0")

def wrapper(*args, **kwargs):
try:
os.environ[env_var] = "1"
return func(*args, **kwargs)
finally:
os.environ[env_var] = value

return wrapper


@_avoid_torch_nccl_record_streams
def copy_default_process_group() -> ProcessGroup:
"""Create a new process group with the same ranks as the default process group.
Returns:
A new process group with the same ranks as the default process group.
"""
default_pg = tdist.distributed_c10d._get_default_group()
ranks = list(range(tdist.get_world_size(group=default_pg)))
backend = tdist.distributed_c10d.get_backend(default_pg)
# Is there a better way to query this from the default process group?
# `False` is the default value for `is_high_priority_stream` in PyTorch.
# `default_pg.options` returns a `ProcessGroup.Options` object, while a
# `ProcessGroupNCCL.Options` object is required here.
options = None
if backend == "nccl":
options = tdist.ProcessGroupNCCL.Options()
options.is_high_priority_stream = False
return tdist.new_group(ranks, backend=backend, pg_options=options)


def set_skip_data_parallel_grad_sync(value: bool) -> Token:
"""Set whether to skip data parallel grad sync.
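As an aside, a generalized sketch of the decorator pattern used by _avoid_torch_nccl_record_streams above; _with_env and make_group are hypothetical names, not part of this commit. It temporarily sets an environment variable for the duration of the wrapped call and restores the previous state afterwards (removing the variable again if it was unset, rather than pinning it to a default).

import functools
import os

def _with_env(var: str, value: str):
    # Temporarily set `var` to `value` while the wrapped function runs.
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            previous = os.environ.get(var)
            os.environ[var] = value
            try:
                return func(*args, **kwargs)
            finally:
                if previous is None:
                    os.environ.pop(var, None)  # was unset before: unset it again
                else:
                    os.environ[var] = previous
        return wrapper
    return decorator

@_with_env("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")
def make_group():
    return os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"]

print(make_group())                                       # "1"
print(os.environ.get("TORCH_NCCL_AVOID_RECORD_STREAMS"))  # restored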
@@ -102,7 +141,7 @@ def _sync_grads(module: torch.nn.Module) -> None:
torch._foreach_div_(grads, process_group.size())
with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
for g in grads:
- tdist.distributed_c10d.all_reduce(g)
+ tdist.distributed_c10d.all_reduce(g, group=process_group)
cm.wait()
elif getattr(module, "use_fsdp", False):

@@ -123,7 +162,9 @@ def prep_shard(
sharded_grads = [prep_shard(g, rank, world_size) for g in unsharded_grads]
with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
for u, s in zip(unsharded_grads, sharded_grads):
- tdist.distributed_c10d.reduce_scatter_tensor(s, u, op=tdist.distributed_c10d.ReduceOp.AVG)
+ tdist.distributed_c10d.reduce_scatter_tensor(
+     s, u, op=tdist.distributed_c10d.ReduceOp.AVG, group=process_group
+ )
cm.wait()
for p, g in zip(params_with_grad, sharded_grads):
p.grad = g
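Since process_group is now a copy of the default group rather than the default group itself, the explicit group= arguments added above matter: without them the collectives would be issued against the default group while the coalescing manager targets process_group. A minimal sketch of the DDP branch as a standalone helper, assuming an initialized NCCL process group and CUDA gradient tensors (average_grads is a hypothetical name; _coalescing_manager is the same private PyTorch API used above).

import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d

def average_grads(grads: list[torch.Tensor], process_group) -> None:
    # Divide locally, then launch all all-reduces as one coalesced batch
    # against the explicit process group and wait for them to finish.
    torch._foreach_div_(grads, process_group.size())
    with distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
        for g in grads:
            dist.all_reduce(g, group=process_group)
    cm.wait()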
@@ -246,7 +287,7 @@ def main():
lambda: "ddp requires torch distributed to be available (but it's not)",
)

- pg = tdist.distributed_c10d._get_default_group()
+ pg = copy_default_process_group()
utils.check(pg is not None, lambda: "The default process group is None")
model.use_ddp = True
model.process_group_for_ddp = pg
@@ -384,7 +425,7 @@ def fsdp_transform_module(
from thunder.core.module import ThunderModule
from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform

- process_group = tdist.distributed_c10d._get_default_group()
+ process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
global_rank = tdist.get_rank(group=process_group)
world_size = tdist.get_world_size(group=process_group)
@@ -549,7 +590,7 @@ def fsdp(
bucketing_strategy=bucketing_strategy,
)

- process_group = tdist.distributed_c10d._get_default_group()
+ process_group = copy_default_process_group()
utils.check(process_group is not None, lambda: "The default process group is None")
model.use_fsdp = True
model.process_group_for_ddp = process_group
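A hypothetical minimal usage of the new helper, assuming torch.distributed has already been initialized (for example via torchrun): the ddp/fsdp entry points above now communicate over a copy of the default group, created while TORCH_NCCL_AVOID_RECORD_STREAMS=1 is in effect, instead of over the default group itself.

import torch.distributed as dist
from thunder.distributed import copy_default_process_group

assert dist.is_initialized()
pg = copy_default_process_group()  # same ranks as the default group
print(dist.get_rank(group=pg), dist.get_world_size(group=pg))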
3 changes: 2 additions & 1 deletion thunder/distributed/tensor_parallel/column_wise.py
@@ -217,11 +217,12 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
from thunder.distributed import _shard_param
from thunder.core.transforms import add_transform
from thunder.core.module import ThunderModule
from thunder.distributed import copy_default_process_group

utils.check_type(thunder_module, ThunderModule)

if process_group is None:
- process_group = distributed_c10d._get_default_group()
+ process_group = copy_default_process_group()
rank = distributed_c10d.get_rank(process_group)
world_size = distributed_c10d.get_world_size(process_group)

3 changes: 2 additions & 1 deletion thunder/distributed/tensor_parallel/row_wise.py
@@ -221,11 +221,12 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
from thunder.distributed import _shard_param
from thunder.core.transforms import add_transform
from thunder.core.module import ThunderModule
from thunder.distributed import copy_default_process_group

utils.check_type(thunder_module, ThunderModule)

if process_group is None:
- process_group = distributed_c10d._get_default_group()
+ process_group = copy_default_process_group()
rank = distributed_c10d.get_rank(process_group)
world_size = distributed_c10d.get_world_size(process_group)

