Possible memory leak when using nn.DataParallel #23
Hi Turlan! Are you certain that you don't create more and more GL contexts as the training progresses? I believe Torch's nn.DataParallel respawns the worker threads between every epoch, and even though this should destroy the contexts and release GPU memory, perhaps something there doesn't work as it should. If you call set_log_level(0), nvdiffrast will log the creation and destruction of GL contexts, which should show whether that is happening. I have only tested nn.DataParallel with custom loaders that avoid the epoch changes and worker thread restarts altogether. I haven't tested with stock nn.DataParallel, but I would expect things to work there too, as there are no known bugs. Would you be able to provide a minimal repro? Also, are you on Linux or Windows?
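For concreteness, a minimal sketch of the logging suggestion above (level 0 is nvdiffrast's most verbose setting, and the same set_log_level call appears in the repro code later in this thread):

import nvdiffrast.torch as dr

# Most verbose log level: nvdiffrast then reports GL context creation and
# destruction, which makes it easy to see whether contexts pile up over time.
dr.set_log_level(0)

ctx = dr.RasterizeGLContext()  # creation should now show up in the log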
Hi s-laine, this is a collaborator of Turlan. We are also testing nn.DataParallel with a simple custom loader (actually just repetitively sending the same data), so there's definitely no epoch issue. Just tried setting the log level as you suggested.
Interestingly, there's no error printed out before the crash. The program simply stopped at an iteration number far smaller than what I set, and as Turlan mentioned, the GPU memory kept increasing while the program was running. By the way, we were testing on Linux.
Based on the log it looks like RasterizeGLContexts are not the issue here. Two GL contexts are created and the internal buffers don't grow after the initial allocations. I would next try to check that references to old tensors aren't mistakenly kept around so that Torch cannot deallocate them. It wouldn't be a surprise if this behaved differently between 1 and 2 GPUs, because with multiple GPUs you need to aggregate gradients and share them between GPUs, and maybe this leaves some references lying around. There is also the possibility that despite cleaning up stale references, there remain circular references between the old tensors so that the objects stay in memory. In this case, calling Python's garbage collector may be required to clean them up quickly enough to avoid running out of GPU memory. To see if that is the case, you can try calling gc.collect() every so often, say once per iteration, and check whether the memory growth stops.
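A small sketch of such a check inside a training loop (the per-iteration work is elided; torch.cuda.memory_allocated reports the memory that Torch's own allocator currently holds):

import gc
import torch

for i in range(1000):
    # ... one training iteration would go here ...
    if i % 10 == 0:
        gc.collect()  # force collection of cyclic garbage that may pin CUDA tensors
        # If this number stays flat while nvidia-smi keeps climbing, the leak is
        # outside the Torch allocator (e.g., on the OpenGL/driver side).
        print(i, torch.cuda.memory_allocated())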
Thank you Samuli for the prompt reply. Just tried calling gc.collect() at every iteration, but it doesn't help: the memory usage reported by Torch stays constant, yet overall GPU memory still grows by roughly 1.7MB per iteration.
If I just use one GPU by setting one visible GPU id (but still using nn.DataParallel), there is no memory leak and the program runs fine.
Very interesting! Torch reports constant memory usage, which rules out any problems with stale tensor references and points to something leaking memory on the OpenGL side. However, the ~1.7MB per iteration doesn't quite match any of the buffers allocated in the rasterization op, and I also have no idea why using two GPUs would leak memory if one GPU doesn't. Contrary to my first comment, I realized that I haven't actually tried using nn.DataParallel but only nn.DistributedDataParallel, which spawns a separate process per GPU. This way each child process uses only one GPU, and that may be why I haven't encountered this problem in my tests. Perhaps this is something you could also consider as a workaround? In any case, I would highly appreciate a minimal repro that could be used to root-cause the issue.
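For reference, a bare-bones sketch of the per-process pattern described above, with one process, one GPU, and one GL context each (the actual DDP model wrapping is omitted; the port number, tensor shapes, and iteration count are arbitrary placeholders):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import nvdiffrast.torch as dr

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'  # arbitrary free port
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Each process owns exactly one GPU and one GL context; no thread migration.
    ctx = dr.RasterizeGLContext(output_db=False, device=rank)
    vertices = torch.randn(2, 3, 4, device=rank)
    faces = torch.arange(3, device=rank, dtype=torch.int32).view(-1, 3)
    with torch.no_grad():
        for _ in range(1000):
            rast, _ = dr.rasterize(ctx, vertices, faces, resolution=(256, 256))
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)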
Tried nn.DistributedDataParallel and there's no memory leak issue! We will use it for now as a workaround. Thanks Samuli. Regarding nn.DataParallel, I'll try to make a minimal repro later to help identify the potential issue.
Hi all, I was also interested in trying out nvdiffrast with nn.DataParallel and ran into the same issue. Firstly, I tried to debug the problem and found out that it is limited to the forward pass (the repro below never runs a backward pass). Here is a minimal repro:

import torch
import torch.nn as nn
import nvdiffrast.torch as dr
import nvdiffrast
from tqdm import tqdm

nvdiffrast.torch.set_log_level(0)

class Rasterizer(nn.Module):
    def __init__(self):
        super().__init__()
        self.ctx = {}

    def forward(self, vertices, faces):
        if vertices.device not in self.ctx:
            self.ctx[vertices.device] = dr.RasterizeGLContext(output_db=False, device=vertices.device)
            print('Created GL context for device', vertices.device)
        ctx = self.ctx[vertices.device]
        rast_out, _ = dr.rasterize(ctx, vertices, faces, resolution=(256, 256))
        return rast_out

gpu_ids = [0, 1]
rasterizer = nn.DataParallel(Rasterizer(), gpu_ids)

bs = 2
nt = 1
vertices = torch.randn(bs, nt*3, 4).cuda()
faces = torch.arange(nt*3).view(-1, 3)
faces_rep = faces.repeat(len(gpu_ids), 1).int().cuda()

with torch.no_grad():
    for i in tqdm(range(100000)):
        rasterizer(vertices, faces_rep)

In the code above, the GL context is lazily initialized in forward(), one context per device. GPU memory keeps increasing as the loop runs, and the program eventually crashes.
In my case, the crash happens in the forward pass of dr.rasterize. For verification, the following code runs just fine:

bs = 16
nt = 512

vertices1 = torch.randn(bs, nt*3, 4).to('cuda:0')
faces1 = torch.arange(nt*3).view(-1, 3).int().to('cuda:0')
ctx1 = dr.RasterizeGLContext(output_db=False, device='cuda:0')

vertices2 = torch.randn(bs, nt*3, 4).to('cuda:1')
faces2 = torch.arange(nt*3).view(-1, 3).int().to('cuda:1')
ctx2 = dr.RasterizeGLContext(output_db=False, device='cuda:1')

with torch.no_grad():
    for i in tqdm(range(1000000)):
        rast_out1, _ = dr.rasterize(ctx1, vertices1, faces1, resolution=(1024, 1024))
        rast_out2, _ = dr.rasterize(ctx2, vertices2, faces2, resolution=(1024, 1024))
        torch.cuda.synchronize('cuda:0')
        torch.cuda.synchronize('cuda:1')

Here, everything is called from the main thread and GPU utilization is 100% for both GPUs, which is good (the calls are asynchronous). No memory leaks or crashes. However, being able to use nn.DataParallel directly would still be convenient. In the end, I came up with the following workaround/hack: I use a rudimentary thread pool for the calls to dr.rasterize, with one dedicated thread (and one GL context) per GPU. Sharing the code for completeness, but don't rely too much on it (I haven't tested edge cases):

import threading  # needed by the Dispatcher below

class Dispatcher:
    def __init__(self, gpu_ids):
        self.threads = {}
        self.events = {}
        self.funcs = {}
        self.return_events = {}
        self.return_values = {}
        for gpu_id in gpu_ids:
            device = torch.device(gpu_id)
            self.events[device] = threading.Event()
            self.return_events[device] = threading.Event()
            self.threads[device] = threading.Thread(target=Dispatcher.worker, args=(self, device,), daemon=True)
            self.threads[device].start()

    @staticmethod
    def worker(self, device):
        # Each worker thread owns its own GL context and never hands it to another thread.
        ctx = dr.RasterizeGLContext(output_db=False, device=device)
        while True:
            self.events[device].wait()
            assert device not in self.return_values
            self.return_values[device] = self.funcs[device](ctx)
            del self.funcs[device]
            self.events[device].clear()
            self.return_events[device].set()

    def __call__(self, device, func):
        # Hand func to the worker thread for this device and block until it returns.
        assert device not in self.funcs
        self.funcs[device] = func
        self.events[device].set()
        self.return_events[device].wait()
        ret_val = self.return_values[device]
        del self.return_values[device]
        self.return_events[device].clear()
        return ret_val

gpu_ids = [0, 1]
dispatcher = Dispatcher(gpu_ids)

class Rasterizer(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, vertices, faces):
        rast_out, _ = dispatcher(vertices.device, lambda ctx: dr.rasterize(ctx, vertices, faces, resolution=(256, 256)))
        return rast_out

rasterizer = nn.DataParallel(Rasterizer(), gpu_ids)

bs = 1024*2
nt = 512
vertices = torch.randn(bs, nt*3, 4).cuda()
faces = torch.arange(nt*3).view(-1, 3)
faces_rep = faces.repeat(len(gpu_ids), 1).int().cuda()

with torch.no_grad():
    for i in tqdm(range(100000)):
        rast_out = rasterizer(vertices, faces_rep)

In the example above, the contexts are created in the dedicated threads, but creating them in the main thread works fine as well. It looks like the real initialization is done during the first call to dr.rasterize. Hope this helps, and if the authors have some insight into this, I would like to hear your opinion!
Hi @dariopavllo, big thanks for posting your analysis and insights here! I hadn't examined the internals of nn.DataParallel before, and I wasn't aware that it spawns a fresh set of worker threads for every forward pass. This means that the OpenGL contexts need to be constantly migrated between threads. This appears to be an expensive and error-prone operation, leading to low GPU utilization and memory leaks. My guess is that some driver-level buffers or data structures of the context, probably related to Cuda-OpenGL interoperability, are thread-specific and they aren't deallocated until the context is destroyed entirely. This would lead to the observed accumulation of crud. There is a function called eglReleaseThread() for releasing per-thread EGL state, which might be relevant here, but even then the constant context migration would presumably remain slow. It makes perfect sense that your workaround solves these issues, as each context is always used from a single, dedicated thread. This was the only usage pattern that I had in mind when developing and testing the code, as I didn't use Torch's built-in nn.DataParallel in my own multi-GPU experiments.
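For what it's worth, the same "one context, one dedicated thread" constraint can also be expressed with standard-library single-worker executors instead of hand-rolled event machinery; this is only a sketch of the pattern described above, not code from nvdiffrast, and the class name and parameters are made up:

import torch
from concurrent.futures import ThreadPoolExecutor
import nvdiffrast.torch as dr

class PerDeviceRasterizer:
    """Pins each RasterizeGLContext to a single dedicated worker thread."""

    def __init__(self, gpu_ids, resolution=(256, 256)):
        self.resolution = resolution
        # One single-threaded executor per GPU; the same worker thread handles
        # every task submitted to it, so the context never changes threads.
        self.pools = {torch.device('cuda', i): ThreadPoolExecutor(max_workers=1)
                      for i in gpu_ids}
        self.ctx = {}
        for dev, pool in self.pools.items():
            # Create the context on the thread that will always use it.
            self.ctx[dev] = pool.submit(dr.RasterizeGLContext,
                                        output_db=False, device=dev).result()

    def rasterize(self, vertices, faces):
        dev = vertices.device
        fut = self.pools[dev].submit(dr.rasterize, self.ctx[dev], vertices, faces,
                                     resolution=self.resolution)
        rast, _ = fut.result()
        return rast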
This is now addressed in the documentation since v0.2.6, with a link to this issue for details. Closing.
Hi, when I use your code to implement multi-GPU training with the provided rasterizer, the GPU memory keeps increasing.
I first define a list of RasterizeGLContext instances, one per GPU, in the __init__ function of a PyTorch nn.Module subclass.
During forward, I choose the RasterizeGLContext instance according to the current device id. The GPU memory keeps increasing whenever I use 2 or more GPUs.
I don't know whether I'm using the code incorrectly or whether there is a bug in your implementation. If possible, could you provide some sample code for multi-GPU training? Thanks!
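The setup described above roughly corresponds to the following sketch (a reconstruction for illustration only, not the reporter's actual code; the class name, gpu_ids, and resolution are placeholders). The leak discussed in this thread appears once such a module is wrapped in nn.DataParallel over two or more GPUs:

import torch
import torch.nn as nn
import nvdiffrast.torch as dr

class MultiGPURasterizer(nn.Module):
    def __init__(self, gpu_ids):
        super().__init__()
        # One GL context per GPU, created up front in __init__.
        self.ctx = {torch.device('cuda', i): dr.RasterizeGLContext(device=i)
                    for i in gpu_ids}

    def forward(self, vertices, faces):
        # Pick the context matching the device of the scattered inputs.
        ctx = self.ctx[vertices.device]
        rast, _ = dr.rasterize(ctx, vertices, faces, resolution=(256, 256))
        return rast

model = nn.DataParallel(MultiGPURasterizer(gpu_ids=[0, 1]), device_ids=[0, 1])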