Skip to content

Commit

Permalink
attempt to fix memory monitor with multiple CUDA devices
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Mar 12, 2023
1 parent 6033de1 commit a00cd8b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions modules/memmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,16 @@ def __init__(self, name, device, opts):
self.data = defaultdict(int)

try:
torch.cuda.mem_get_info()
self.cuda_mem_get_info()
torch.cuda.memory_stats(self.device)
except Exception as e: # AMD or whatever
print(f"Warning: caught exception '{e}', memory monitor disabled")
self.disabled = True

def cuda_mem_get_info(self):
    """Query (free, total) memory in bytes for this monitor's CUDA device.

    Resolves an explicit device index from ``self.device`` when one is set;
    otherwise falls back to the currently selected CUDA device, so the query
    targets the right card on multi-GPU systems.
    """
    index = self.device.index
    if index is None:
        index = torch.cuda.current_device()
    return torch.cuda.mem_get_info(index)

def run(self):
if self.disabled:
return
Expand All @@ -43,10 +47,10 @@ def run(self):
self.run_flag.clear()
continue

self.data["min_free"] = torch.cuda.mem_get_info()[0]
self.data["min_free"] = self.cuda_mem_get_info()[0]

while self.run_flag.is_set():
free, total = torch.cuda.mem_get_info() # calling with self.device errors, torch bug?
free, total = self.cuda_mem_get_info()
self.data["min_free"] = min(self.data["min_free"], free)

time.sleep(1 / self.opts.memmon_poll_rate)
Expand All @@ -70,7 +74,7 @@ def monitor(self):

def read(self):
if not self.disabled:
free, total = torch.cuda.mem_get_info()
free, total = self.cuda_mem_get_info()
self.data["free"] = free
self.data["total"] = total

Expand Down

0 comments on commit a00cd8b

Please sign in to comment.