Enable TORCH_NCCL_AVOID_RECORD_STREAMS=1 by default #512

Merged · 17 commits · Jun 5, 2024
3 changes: 3 additions & 0 deletions thunder/benchmarks/benchmark_litgpt.py
@@ -27,6 +27,9 @@
 local_rank = int(os.environ.get("LOCAL_RANK", 0))
 global_rank = int(os.environ.get("RANK", 0))
 if world_size > 1:
+    # Avoids the allocator thrashing issue in PyTorch NCCL backend.
+    # See https://github.com/Lightning-AI/lightning-thunder/issues/420
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     torch_dist.init_process_group(backend="nccl")
     pg = torch_dist.distributed_c10d._get_default_group()
 device = torch.device("cuda", local_rank)
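For reference, a minimal standalone sketch of the same pattern (not taken from the benchmark; the `run` function and the `torchrun`-style `LOCAL_RANK` handling are illustrative). The ordering matters: the flag must be exported before `init_process_group`, since the NCCL process group picks it up when it is constructed.

import os

import torch
import torch.distributed as dist


def run() -> None:
    # Export the flag before the NCCL process group is created.
    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
    dist.init_process_group(backend="nccl")

    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    torch.cuda.set_device(local_rank)

    # Collectives issued after this point use the record-stream-free path.
    t = torch.ones(1, device="cuda")
    dist.all_reduce(t)

    dist.destroy_process_group()


if __name__ == "__main__":
    run()  # e.g. torchrun --nproc-per-node=2 this_script.py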
43 changes: 40 additions & 3 deletions thunder/distributed/__init__.py
@@ -38,6 +38,43 @@
 _skip_data_parallel_grad_sync = ContextVar("skip_data_parallel_grad_sync", default=False)
 
 
+def _avoid_torch_nccl_record_streams(func):
+    """
+    Avoids the allocator thrashing issue in PyTorch NCCL backend.
+    """
+
+    env_var = "TORCH_NCCL_AVOID_RECORD_STREAMS"
+    value = os.environ.get(env_var, "0")
+
+    def wrapper(*args, **kwargs):
+        try:
+            os.environ[env_var] = "1"
+            return func(*args, **kwargs)
+        finally:
+            os.environ[env_var] = value
+
+    return wrapper
+
+
+@_avoid_torch_nccl_record_streams
+def copy_default_process_group() -> ProcessGroup:
+    """Create a new process group with the same ranks as the default process group.
+
+    Returns:
+        A new process group with the same ranks as the default process group.
+    """
+    default_pg = tdist.distributed_c10d._get_default_group()
+    ranks = list(range(tdist.get_world_size(group=default_pg)))
+    backend = tdist.distributed_c10d.get_backend(default_pg)
+    # What's the better way to query this from the default process group? This
+    # is the default value for `is_high_priority_stream` in PyTorch
+    # default_pg.options returns ProcessGroup.Options object while
+    # ProcessGroupNCCL.Options is required
+    options = tdist.ProcessGroupNCCL.Options()
+    options.is_high_priority_stream = False
+    return tdist.new_group(ranks, backend=backend, pg_options=options)
+
+
 def set_skip_data_parallel_grad_sync(value: bool) -> Token:
     """Set whether to skip data parallel grad sync.
 
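An illustrative usage sketch of the new helper (not part of the diff; assumes a `torchrun` launch with CUDA available). The decorator only flips TORCH_NCCL_AVOID_RECORD_STREAMS to "1" for the duration of the `tdist.new_group` call, so it is the returned copy of the default group, not the default group itself, that is created with the setting.

import os

import torch
import torch.distributed as tdist

from thunder.distributed import copy_default_process_group

tdist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

# New group with the same ranks as the default one, created while the
# record-streams flag is set to "1".
pg = copy_default_process_group()

t = torch.ones(1, device="cuda")
tdist.all_reduce(t, group=pg)  # communicates over the copied group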
@@ -246,7 +283,7 @@ def main():
lambda: "ddp requires torch distributed to be available (but it's not)",
)

pg = tdist.distributed_c10d._get_default_group()
pg = copy_default_process_group()
utils.check(pg is not None, lambda: "The default process group is None")
model.use_ddp = True
model.process_group_for_ddp = pg
@@ -384,7 +421,7 @@ def fsdp_transform_module(
     from thunder.core.module import ThunderModule
     from thunder.distributed.transforms.fsdp_v2 import FSDPTraceTransform
 
-    process_group = tdist.distributed_c10d._get_default_group()
+    process_group = copy_default_process_group()
     utils.check(process_group is not None, lambda: "The default process group is None")
     global_rank = tdist.get_rank(group=process_group)
     world_size = tdist.get_world_size(group=process_group)
@@ -549,7 +586,7 @@ def fsdp(
         bucketing_strategy=bucketing_strategy,
     )
 
-    process_group = tdist.distributed_c10d._get_default_group()
+    process_group = copy_default_process_group()
     utils.check(process_group is not None, lambda: "The default process group is None")
     model.use_fsdp = True
     model.process_group_for_ddp = process_group
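With `ddp`, `fsdp_transform_module`, and `fsdp` now calling `copy_default_process_group()`, users no longer need to export the variable themselves. A rough sketch of the intended usage (the toy model and shapes are made up; the `thunder.distributed.fsdp` followed by `thunder.jit` pattern follows the existing thunder examples):

import os

import torch
import torch.distributed as tdist

import thunder
import thunder.distributed

tdist.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

model = torch.nn.Linear(1024, 1024, device="cuda")  # placeholder model

# fsdp() (and ddp()) now communicate over a copy of the default process group
# that was created with TORCH_NCCL_AVOID_RECORD_STREAMS=1.
sharded = thunder.distributed.fsdp(model)
jitted = thunder.jit(sharded)

out = jitted(torch.randn(8, 1024, device="cuda"))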
3 changes: 2 additions & 1 deletion thunder/distributed/tensor_parallel/column_wise.py
@@ -220,11 +220,12 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
     from thunder.distributed import _shard_param
     from thunder.core.transforms import add_transform
     from thunder.core.module import ThunderModule
+    from thunder.distributed import copy_default_process_group
 
     utils.check_type(thunder_module, ThunderModule)
 
     if process_group is None:
-        process_group = distributed_c10d._get_default_group()
+        process_group = copy_default_process_group()
     rank = distributed_c10d.get_rank(process_group)
     world_size = distributed_c10d.get_world_size(process_group)
 
3 changes: 2 additions & 1 deletion thunder/distributed/tensor_parallel/row_wise.py
@@ -222,11 +222,12 @@ def forward(self, tokens: torch.Tensor) -> torch.Tensor:
     from thunder.distributed import _shard_param
     from thunder.core.transforms import add_transform
     from thunder.core.module import ThunderModule
+    from thunder.distributed import copy_default_process_group
 
     utils.check_type(thunder_module, ThunderModule)
 
     if process_group is None:
-        process_group = distributed_c10d._get_default_group()
+        process_group = copy_default_process_group()
     rank = distributed_c10d.get_rank(process_group)
     world_size = distributed_c10d.get_world_size(process_group)
 
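The tensor-parallel transforms get the same fallback: when no `process_group` is passed, they now use the copied default group. A hedged sketch, assuming the public `column_parallel`/`row_parallel` entry points and a `target_modules` argument naming the submodules to shard (those names are an assumption, not confirmed by this diff):

import os

import torch
import torch.distributed as tdist

import thunder
from thunder.distributed.tensor_parallel import column_parallel, row_parallel

tdist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))

model = torch.nn.Sequential(
    torch.nn.Linear(256, 1024, device="cuda"),
    torch.nn.GELU(),
    torch.nn.Linear(1024, 256, device="cuda"),
)
jitted = thunder.jit(model)

# No process_group argument, so both transforms fall back to
# copy_default_process_group().
jitted = column_parallel(jitted, target_modules=["0"])
jitted = row_parallel(jitted, target_modules=["2"])

out = jitted(torch.randn(4, 256, device="cuda"))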