Is there any indicator or flag indicating which algorithm a communication is using? #754
@sjeaugey Could you please help answer this question?
Short answer is no, there is no such log. You can uncomment this line to see it: https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L1158

As for why: we can't print a line on every call (that would have a perf impact), and we also can't easily deduce which algorithm we're using. There are 4 algorithms now (Ring, Tree, CollnetChain and CollnetDirect) × 3 protocols (LL, LL128, Simple); even restricting to Ring and Tree, that's still 6 combinations. The internal tuning model will switch from one to the other depending on the number of nodes, the number of GPUs per node, and the size of the operation, but also the intra-node and inter-node bandwidth. NCCL will move from one combination to another depending on the size; for example it could do tree/LL -> ring/LL -> tree/LL128 -> ring/LL128 -> tree/simple -> ring/simple, or any other combination. So there is no easy rule to determine which algorithm/protocol is used.
Hi,
thanks for your response. From my survey of the code, the algorithm is determined in the tuner file. Would it be possible to output the selected algorithm by adding a line of code in that file? I am not familiar with the NCCL code, so I am not sure.
I'm not sure what you mean. Yes, for each size, the algorithm/protocol is deterministic. But it's still dependent on the size.
Hi Sylvain Jeaugey, sorry for my unclear wording; I will try to express my thought more clearly. In https://github.com/NVIDIA/nccl/blob/master/src/graph/tuning.cc, the algorithm (ring, tree, etc.) is determined according to the environment and data, and never changes after this initialization. Would it be reasonable for me to add some logs in this file to see which algorithm is selected, given the environment and data? To my understanding, the selection in this file is one-shot, so it should not incur much performance degradation.

The reason I am curious about the selected algorithm is that I want to build a prediction model to forecast the mini-batch time of multi-GPU, multi-node training for a research idea. Sometimes the communication overhead is linear in the number of workers/nodes and sometimes it is not, and I suspect this depends on the algorithm selection. So awareness of the algorithm may be important for my prediction. If I have gone astray, could you please point me in the right direction? Thanks a lot.
Sure, you could add traces and extract the settings. You can also set NCCL_DEBUG=INFO NCCL_DEBUG_SUBSYS=TUNING to get the numbers, and then run ncclTopoGetAlgoTime() on the size you care about.
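For reference, the two variables mentioned above can simply be set in the environment before launching any NCCL workload. NCCL_DEBUG and NCCL_DEBUG_SUBSYS are real NCCL environment variables; the launch command below is a placeholder, not something from this thread:

```shell
# NCCL_DEBUG=INFO with NCCL_DEBUG_SUBSYS=TUNING makes NCCL log its tuning
# table (latency/bandwidth numbers) at init time.
export NCCL_DEBUG=INFO
export NCCL_DEBUG_SUBSYS=TUNING
# ./your_nccl_program   # placeholder: replace with your actual launch command
```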
Hi, I downloaded the source code from GitHub and uncommented the line https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L1178 to check which algorithm it uses. Then I built and reinstalled NCCL according to https://github.com/NVIDIA/nccl/blob/master/README.md. However, it didn't output any log as expected. Should I remove the original NCCL on my host first?
You may want to modify
Hi Sylvain Jeaugey, thanks for your quick reply. According to your comments, I found my current version is not labeled as
and then:
I didn't encounter any error in the console. Am I going about customizing the NCCL code the wrong way? I am new to programs written in C, so maybe this is a silly question. Thanks for your reply!
Did you install the deb packages with
Now you may just want to try that version without installing it. In that case you can just set:
And then run again. |
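The exact variable the maintainer suggested is elided above. A common way to try a locally built shared library without installing it is via LD_LIBRARY_PATH; this is offered here purely as an assumption about what was meant, and the build path is hypothetical:

```shell
# Assumption: pointing the dynamic loader at a freshly built libnccl, without
# installing it system-wide. /path/to/nccl/build/lib is a hypothetical path.
export LD_LIBRARY_PATH=/path/to/nccl/build/lib:$LD_LIBRARY_PATH
```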
Hi, since I am testing NCCL under PyTorch, I found that I may need to rebuild PyTorch with a customized NCCL path, as discussed in https://discuss.pytorch.org/t/how-can-i-change-nccl-version-in-pytorch/143942. I will try rebuilding them to see whether it works. Anyway, thank you for your advice.
Hi, I want to know: does NCCL always pick the algorithm yielding the minimum value in the last row when using AllReduce? In this case, is
The first number is the latency and the second is the bandwidth. Anything with 0 bandwidth will be dismissed (infinite time). LL has low latency but low bandwidth; LL128 has more latency but more bandwidth; Simple has even more latency and even more bandwidth. The formula is not always just latency + size/bandwidth: for Tree we apply a correction factor (it's ugly, but that's what we have at the moment).
Thanks!
Hi,
I know there are several algorithms that NCCL can use to communicate among ranks in the all-reduce scenario (ring, tree, and others), depending on the cluster's environment settings. But in a given case I cannot tell which algorithm is being used. Is there any indicator or flag in the log that tells the user which algorithm is in use?