
Enable TORCH_NCCL_AVOID_RECORD_STREAMS=1 by default #512

Merged: 17 commits from fix-420 into Lightning-AI:main on Jun 5, 2024

Conversation

IvanYashchuk (Collaborator) commented on Jun 3, 2024:

This PR enables the magic environment variable by default for Thunder; the change is restricted to Thunder only.
This magic environment variable is supposed to fix a problem with allocator thrashing when using collectives from PyTorch's NCCL backend.

I have tested performance with the command provided by @parthmannan in #420

torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-13b-hf --compile thunder_cudnn --distributed_mode fsdp --shard_mode zero2 --bucketing_mode none --micro_batch_size 1 --global_batch_size 8

and this PR gives a ~2.11x performance improvement (1517 ms -> 716 ms).

Fixes #420.
Fixes #477.

cc @parthmannan

(Review thread on thunder/core/trace.py: outdated, resolved)
t-vi (Collaborator) commented on Jun 3, 2024:

Does this work, though? From the PyTorch source it seems that it needs to be set when the process group is created:

https://github.com/pytorch/pytorch/blob/f343f98710dfa7305a873f558086c595a3c3d3d4/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L776

IvanYashchuk (Collaborator, Author) commented on Jun 3, 2024:

> Does this work, though? From the PyTorch source it seems that it needs to be set when the process group is created:
>
> https://github.com/pytorch/pytorch/blob/f343f98710dfa7305a873f558086c595a3c3d3d4/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L776

Yes, getCvarBool parses the environment variable with std::getenv: https://github.com/pytorch/pytorch/blob/f343f98710dfa7305a873f558086c595a3c3d3d4/torch/csrc/distributed/c10d/Utils.hpp#L160
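
For a quick sanity check of what the C++ side will see, a minimal sketch (not part of this PR): Python's os.environ writes through to the same process environment that std::getenv reads, so the value can be inspected and set from Python.

import os

# Whatever value is set here (or inherited from the launching shell) is what
# std::getenv("TORCH_NCCL_AVOID_RECORD_STREAMS") returns inside
# ProcessGroupNCCL, because os.environ writes through to the C environment.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
print(os.environ.get("TORCH_NCCL_AVOID_RECORD_STREAMS", "<unset>"))  # prints "1"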

lantiga (Collaborator) commented on Jun 3, 2024:

As a side note, in relation to pytorch/pytorch#76861 (comment):

It would be interesting to understand what happens if we automatically insert wait_stream at the end of a region (or the whole model iteration), instead of disabling record streams entirely.
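
For illustration, manual synchronization of that kind would look roughly like the sketch below, using the public torch.cuda stream API (the matmul is a hypothetical stand-in for a collective; Thunder does not model streams today, as noted in the reply that follows):

import torch

x = torch.randn(1024, 1024, device="cuda")
side_stream = torch.cuda.Stream()

# Make the side stream wait for the default stream that produced `x`.
side_stream.wait_stream(torch.cuda.current_stream())

# Run the work (a stand-in for an NCCL collective) on the side stream.
with torch.cuda.stream(side_stream):
    out = x @ x

# Instead of calling out.record_stream(side_stream), make the default stream
# wait for the side stream at the end of the region, so the caching allocator
# can reuse the memory without per-tensor stream tracking.
torch.cuda.current_stream().wait_stream(side_stream)
print(out.sum())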

IvanYashchuk (Collaborator, Author):

> As a side note, in relation to pytorch/pytorch#76861 (comment):
>
> It would be interesting to understand what happens if we automatically insert wait_stream at the end of a region (or the whole model iteration), instead of disabling record streams entirely.

Thunder doesn't know anything about CUDA streams today.

lantiga (Collaborator) commented on Jun 3, 2024:

> > It would be interesting to understand what happens if we automatically insert wait_stream at the end of a region (or the whole model iteration), instead of disabling record streams entirely.
>
> Thunder doesn't know anything about CUDA streams today.

Fair point

IvanYashchuk (Collaborator, Author):

Setting the environment variable from the command line or in thunder/__init__.py results in consistently better performance, suggesting that the current approach in this PR doesn't quite work:

TORCH_NCCL_AVOID_RECORD_STREAMS=1 torchrun --nproc_per_node=8 --nnodes=1 thunder/benchmarks/benchmark_litgpt.py --model_name Llama-2-13b-hf --compile thunder_cudnn --distributed_mode fsdp --shard_mode zero2 --bucketing_mode none --micro_batch_size 1 --global_batch_size 8

Average iter time: 722.36 ms

IvanYashchuk (Collaborator, Author):

> Setting the environment variable from the command line or in thunder/__init__.py results in consistently better performance, suggesting that the current approach in this PR doesn't quite work:
>
> Average iter time: 722.36 ms

Maybe the backward call is not affected by this decorator because it runs on a separate thread, and setting the env variable in thunder/__init__.py works because then the side C++ thread of PyTorch's Autograd engine inherits the variable value.

IvanYashchuk (Collaborator, Author):

> Does this work, though? From the PyTorch source it seems that it needs to be set when the process group is created:
>
> https://github.com/pytorch/pytorch/blob/f343f98710dfa7305a873f558086c595a3c3d3d4/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp#L776

Right, and this needs to be set at process group creation, which in typical scripts means before the init_process_group(backend="nccl") call.
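
In user scripts that ordering looks roughly like this (a minimal sketch, not the code from this PR):

import os
import torch.distributed as dist

# ProcessGroupNCCL reads the variable when the process group is constructed,
# so it has to be in the environment before init_process_group runs.
os.environ.setdefault("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")

dist.init_process_group(backend="nccl")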

(Review thread on thunder/__init__.py: outdated, resolved)
IvanYashchuk requested a review from @crcrpar on June 3, 2024.
parthmannan (Collaborator):

This would help fix a lot of the performance issues we have been seeing with large models and large batch sizes, where memory thrashing occurs. Thanks for working on this, Ivan 🚀

t-vi enabled auto-merge (squash) on June 4, 2024.
t-vi (Collaborator) commented on Jun 4, 2024:

Seems like legit CI failures in distributed. This seems to be a common pattern

  File "/__w/1/s/thunder/core/module.py", line 142, in no_sync
    _sync_grads(self)
  File "/__w/1/s/thunder/distributed/__init__.py", line 142, in _sync_grads
    with tdist.distributed_c10d._coalescing_manager(group=process_group, async_ops=True) as cm:
  File "/usr/lib/python3.10/contextlib.py", line 142, in __exit__
    next(self.gen)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 2079, in _coalescing_manager
    cm.append(work)  # type: ignore[possibly-undefined]
UnboundLocalError: local variable 'work' referenced before assignment

in these failures:

FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_torch_bucket_size_in_mb_0_dataset_size_1 - RuntimeError: Process 0 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_fsdp_with_no_sync_grad_accumulation_executor_torch_bucketing_block_zero2 - RuntimeError: Process 0 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_fsdp_with_no_sync_grad_accumulation_executor_nvfuser_bucketing_block_zero2 - RuntimeError: Process 1 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_nvfuser_bucket_size_in_mb_25_dataset_size_2 - RuntimeError: Process 0 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_fsdp_with_no_sync_grad_accumulation_executor_nvfuser_bucketing_block_zero3 - RuntimeError: Process 0 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_nvfuser_bucket_size_in_mb_25_dataset_size_1 - RuntimeError: Process 1 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_torch_bucket_size_in_mb_25_dataset_size_2 - RuntimeError: Process 1 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_fsdp_with_no_sync_grad_accumulation_executor_torch_bucketing_block_zero3 - RuntimeError: Process 0 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_nvfuser_bucket_size_in_mb_0_dataset_size_1 - RuntimeError: Process 1 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_nvfuser_bucket_size_in_mb_0_dataset_size_2 - RuntimeError: Process 1 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_torch_bucket_size_in_mb_25_dataset_size_1 - RuntimeError: Process 1 exited with error code 10 and exception:
FAILED thunder/tests/distributed/test_ddp.py::CompileDDPTest::test_ddp_with_no_sync_grad_accumulation_executor_torch_bucket_size_in_mb_0_dataset_size_2 - RuntimeError: Process 0 exited with error code 10 and exception:
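
For context, the UnboundLocalError above is a plain Python failure mode: a name assigned only on one branch inside the coalescing context manager is referenced unconditionally on exit. A minimal reproduction of the pattern, independent of PyTorch (all names here are hypothetical):

from contextlib import contextmanager

@contextmanager
def coalescing(ops_queued: bool):
    # Mirrors the shape of _coalescing_manager: `work` is only assigned when
    # there is something to launch, but it is used unconditionally on exit.
    yield
    if ops_queued:
        work = "launched collective"  # stand-in for the real work handle
    print(work)  # UnboundLocalError when ops_queued is False

with coalescing(ops_queued=False):
    pass  # raises UnboundLocalError: local variable 'work' referenced before assignment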

IvanYashchuk (Collaborator, Author):

It's unfortunate that the error comes from the PyTorch source code and is not reproducible with my build. I'll fix it.

t-vi (Collaborator) left a comment:


Thank you @IvanYashchuk

t-vi merged commit 23da3c1 into Lightning-AI:main on Jun 5, 2024.
36 checks passed
lantiga (Collaborator) commented on Jun 5, 2024:

Awesome @IvanYashchuk!

IvanYashchuk deleted the fix-420 branch on June 5, 2024.