Memory Leak in 2.10.4 release #604

Open

xw285cornell opened this issue Dec 1, 2021 · 6 comments

@xw285cornell
Just want to seek some clarification on the memory leak issue that is fixed in the 2.11 release (per the release notes). Can you give us some details about the leak? We're noticing two potential memory leaks in the NCCL 2.10 release, both of which seem to be related to communicator initialize + abort:

  1. The leak seems to be related to NVB on HCM topologies. Setting NCCL_NVB_PRECONNECT=0 seems to fix the problem.
  2. The health check in PyTorch ([PG NCCL] Disable NCCL health check pytorch/pytorch#67668). It creates a communicator and then aborts it to verify the health of the host, but it is causing a memory regression, so we had to turn it off. This does not seem to be related to HCM (we observe it on NVSwitch hosts); see the sketch after this list.
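
For reference, here is a minimal sketch of that init-then-abort pattern (this is not the actual PyTorch health-check code; the single-rank communicator and the fixed device index are simplifying assumptions):

// healthcheck_sketch.cc -- illustrative only, not the PyTorch health check itself.
// Creates the smallest possible communicator on one GPU and immediately aborts it.
#include <cuda_runtime.h>
#include <nccl.h>
#include <stdio.h>

int main() {
  ncclComm_t comm;
  ncclUniqueId id;
  if (cudaSetDevice(0) != cudaSuccess) return 1;      // assumes GPU 0 is usable
  if (ncclGetUniqueId(&id) != ncclSuccess) return 1;  // single process, so no id exchange
  // nranks=1, rank=0: a one-rank communicator is enough to probe host health.
  if (ncclCommInitRank(&comm, /*nranks=*/1, id, /*rank=*/0) != ncclSuccess) return 1;
  // Abort rather than Destroy, mirroring what the health check does on teardown.
  ncclCommAbort(comm);
  printf("health check communicator aborted\n");
  return 0;
}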
@xw285cornell
Author

cc. @kwen2501 @sjeaugey

@AddyLaddy
Collaborator

Yes, the memory leak was due to the intermediate GPU IPC memory not being released on NVB-based topologies during communicator destruction. It does not occur on NVSwitch-based systems.
The memory is always leaked when the communicator is destroyed, whether via Abort or Destroy.

I have extensively tested Communicator Destroy and Abort calls, but so far I have not detected any additional memory leaks.

Do you need to see the code diffs for the NVB memory leak fix?

@xw285cornell
Author

@AddyLaddy sure, the code diff would be great! Also, is there a unit test or something similar that exercises destroy/abort that we could try out?

@AddyLaddy
Collaborator

I think these are all the changes you need to fix the NVB memory leak:

diff --git a/src/bootstrap.cc b/src/bootstrap.cc
index f5e9f565..ae9da9be 100644
--- a/src/bootstrap.cc
+++ b/src/bootstrap.cc
@@ -202,7 +202,7 @@ struct unexConn {
 struct remAllocState {
   int cudaDev;
   int listenFd;
-  int stop;
+  volatile int stop;
 };
 
 struct extState {
@@ -257,7 +257,7 @@ void* ncclRemoteMemAllocationService(void* args) {
   for (int s=0; s<MAX_SEGMENTS; s++) segments[s] = NULL;
   for (int s=0; s<MAX_SEGMENTS; s++) {
     pollfds[s].fd = -1;
-    pollfds[s].events = POLLHUP;
+    pollfds[s].events = POLLIN;
   }
   pollfds[MAX_SEGMENTS].fd = state->listenFd;
   pollfds[MAX_SEGMENTS].events = POLLIN;
@@ -285,7 +285,7 @@ void* ncclRemoteMemAllocationService(void* args) {
       }
     }
     for (int s=0; s<MAX_SEGMENTS; s++) {
-      if (pollfds[s].revents & POLLHUP) {
+      if (pollfds[s].revents & (POLLIN|POLLHUP)) {
         if (cudaFree(segments[s]) != cudaSuccess) {
           WARN("[Rem Allocator] cudaFree %p failed", segments[s]);
         }
diff --git a/src/transport/p2p.cc b/src/transport/p2p.cc
index 38ac57dc..5bd92b11 100644
--- a/src/transport/p2p.cc
+++ b/src/transport/p2p.cc
@@ -21,6 +21,7 @@ struct p2pSendResources {
   void* ipcPtr;
   int remoteId;
   int memRank;
+  void* remIpcPtr;
   void* bootstrap;
 };
 
@@ -29,6 +30,7 @@ struct p2pRecvResources {
   void* ipcPtr;
   int remoteId;
   int memRank;
+  void* remIpcPtr;
   void* bootstrap;
 };
 
@@ -252,7 +254,7 @@ static ncclResult_t p2pSendConnect(struct ncclComm* comm, struct ncclConnect* co
   struct ncclRecvMem* remDevMem;
   struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
 
-  NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr));
+  NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->remIpcPtr));
 
   int offset = 0;
   for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
@@ -276,7 +278,7 @@ ncclResult_t p2pRecvConnect(struct ncclComm* comm, struct ncclConnect* connectIn
   struct ncclSendMem* remDevMem;
   struct p2pConnectInfo* info = (struct p2pConnectInfo*)connectInfo;
 
-  NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->ipcPtr));
+  NCCLCHECK(p2pMap(comm->peerInfo+rank, comm->peerInfo+info->rank, info, (void**)&remDevMem, &resources->remIpcPtr));
 
   int offset = 0;
   for (int p=0; p<NCCL_NUM_PROTOCOLS; p++) {
@@ -298,6 +300,8 @@ ncclResult_t p2pSendFree(void* resources) {
   struct p2pSendResources* sendRes = (struct p2pSendResources*)resources;
   if (sendRes->ipcPtr)
     CUDACHECK(cudaIpcCloseMemHandle(sendRes->ipcPtr));
+  if (sendRes->remIpcPtr)
+    CUDACHECK(cudaIpcCloseMemHandle(sendRes->remIpcPtr));
   if (sendRes->remoteId != -1) {
     NCCLCHECK(bootstrapRemFree(sendRes->remoteId, sendRes->memRank, sendRes->bootstrap));
     sendRes->devMem = NULL;
@@ -311,6 +315,8 @@ ncclResult_t p2pRecvFree(void* resources) {
   struct p2pRecvResources* recvRes = (struct p2pRecvResources*)resources;
   if (recvRes->ipcPtr)
     CUDACHECK(cudaIpcCloseMemHandle(recvRes->ipcPtr));
+  if (recvRes->remIpcPtr)
+    CUDACHECK(cudaIpcCloseMemHandle(recvRes->remIpcPtr));
   if (recvRes->remoteId != -1) {
     NCCLCHECK(bootstrapRemFree(recvRes->remoteId, recvRes->memRank, recvRes->bootstrap));
     recvRes->devMem = NULL;

@kwen2501
Contributor

kwen2501 commented Dec 2, 2021

Thanks @AddyLaddy!

I wrote a simple program that creates and destroys an NCCL communicator in a loop.
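
Something along these lines, as a single-process sketch (the device count, iteration count, and MiB reporting are my simplifications; the real test may differ):

// comm_cycle_sketch.cc -- init/destroy an NCCL communicator in a loop and report free GPU memory.
#include <cuda_runtime.h>
#include <nccl.h>
#include <stdio.h>

int main() {
  const int ndev = 4;                 // assumption: 4 interconnected GPUs
  ncclComm_t comms[ndev];
  for (int iter = 0; iter < 10; iter++) {
    // devlist=nullptr: use the first ndev CUDA devices.
    if (ncclCommInitAll(comms, ndev, nullptr) != ncclSuccess) return 1;
    for (int i = 0; i < ndev; i++) ncclCommDestroy(comms[i]);
    printf("========== After %2d (init, destroy) ==========\n", iter);
    for (int i = 0; i < ndev; i++) {
      size_t freeB, totalB;
      cudaSetDevice(i);
      cudaMemGetInfo(&freeB, &totalB);
      printf("GPU %d: total %zu MiB free %zu MiB\n", i, totalB >> 20, freeB >> 20);
    }
  }
  return 0;
}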

Here is the output with 2.11.4, run with 4 interconnected GPUs on a DGX-1-like machine (so there should be no NVB leak):

========== After  0 (init, destroy) ==========
GPU 1: total 16160 MiB free 15853 MiB
GPU 0: total 16160 MiB free 15853 MiB
GPU 3: total 16160 MiB free 15853 MiB
GPU 2: total 16160 MiB free 15853 MiB
NCCL version 2.11.4+cuda11.0
========== After  1 (init, destroy) ==========
GPU 3: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
GPU 0: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
========== After  2 (init, destroy) ==========
GPU 3: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
GPU 0: total 16160 MiB free 15803 MiB
========== After  3 (init, destroy) ==========
GPU 0: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
GPU 3: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
========== After  4 (init, destroy) ==========
GPU 1: total 16160 MiB free 15803 MiB
GPU 3: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
GPU 0: total 16160 MiB free 15803 MiB
========== After  5 (init, destroy) ==========
GPU 0: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
GPU 3: total 16160 MiB free 15803 MiB
========== After  6 (init, destroy) ==========
GPU 3: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
GPU 0: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
========== After  7 (init, destroy) ==========
GPU 3: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
GPU 0: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
========== After  8 (init, destroy) ==========
GPU 3: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
GPU 0: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
========== After  9 (init, destroy) ==========
GPU 0: total 16160 MiB free 15803 MiB
GPU 2: total 16160 MiB free 15803 MiB
GPU 1: total 16160 MiB free 15803 MiB
GPU 3: total 16160 MiB free 15803 MiB

It looks like 50 MiB is not freed after the first iteration, but after that the free memory stays stable.

Another finding is that this amount varies from one NCCL version to another:
2.11.4: 50 MiB
2.10.3: 40 MiB
2.9.9: 30 MiB
2.8.4: 30 MiB

I still need to think about where the 50 MiB comes from, but it seems to be related to some static initialization, as it does not increase after the first iteration.

Appreciate your help in understanding this!

@AddyLaddy
Collaborator

Yes, there is some memory consumed by the CUDA runtime that is not released. Perhaps the CUDA malloc heap and other context info?

We have a test that does similar checks. Due to the 'background' memory usage of the CUDA runtime, I sample the CUDA memory stats after a warmup communicator Alloc/Destroy, so that I am only looking for NCCL memory leaks.
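
Roughly this pattern, as a sketch (assuming one warmup init/destroy cycle is enough to settle the CUDA runtime's own allocations; the device and iteration counts are placeholders):

// leak_check_sketch.cc -- warm up once, record a free-memory baseline, then flag further drops.
#include <cuda_runtime.h>
#include <nccl.h>
#include <stdio.h>

static void cycle(ncclComm_t* comms, int ndev) {
  ncclCommInitAll(comms, ndev, nullptr);
  for (int i = 0; i < ndev; i++) ncclCommDestroy(comms[i]);
}

int main() {
  const int ndev = 4;                 // assumption: 4 GPUs
  ncclComm_t comms[ndev];
  size_t baseline[ndev], totalB;

  cycle(comms, ndev);                 // warmup absorbs the CUDA runtime's 'background' usage
  for (int i = 0; i < ndev; i++) {
    cudaSetDevice(i);
    cudaMemGetInfo(&baseline[i], &totalB);
  }

  for (int iter = 0; iter < 10; iter++) {
    cycle(comms, ndev);
    for (int i = 0; i < ndev; i++) {
      size_t freeB;
      cudaSetDevice(i);
      cudaMemGetInfo(&freeB, &totalB);
      if (freeB < baseline[i])
        printf("GPU %d: possible NCCL leak, %zu MiB below baseline\n",
               i, (baseline[i] - freeB) >> 20);
    }
  }
  return 0;
}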
