[Platform] Refactor memory manage function in memory_profiling to Platform #13599

Closed · wants to merge 2 commits
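At a high level, this PR moves the `torch.cuda.*` memory-management calls behind the Platform abstraction so that device-agnostic code (such as `memory_profiling` and `MemorySnapshot`) dispatches through `current_platform` instead of hard-coding CUDA. The following is a minimal sketch of that pattern, not the actual vLLM classes: the names are simplified, only two hooks are shown, and it assumes a CUDA-enabled PyTorch install.

# Minimal sketch of the refactoring pattern (simplified, not the real vLLM code).
# The base Platform declares memory hooks, a CUDA-like platform overrides them
# with torch.cuda calls, and call sites stay device-agnostic by going through
# the resolved platform object.
from typing import Tuple

import torch


class Platform:

    @classmethod
    def empty_cache(cls) -> None:
        raise NotImplementedError

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        """Return (free_bytes, total_bytes) for the current device."""
        raise NotImplementedError


class CudaPlatform(Platform):

    @classmethod
    def empty_cache(cls) -> None:
        torch.cuda.empty_cache()

    @classmethod
    def mem_get_info(cls) -> Tuple[int, int]:
        return torch.cuda.mem_get_info()


# In vLLM the active platform is resolved automatically; here we pick it by hand.
current_platform = CudaPlatform()
current_platform.empty_cache()  # call sites never touch torch.cuda directly
free_bytes, total_bytes = current_platform.mem_get_info()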
tests/worker/test_profile.py (4 changes: 3 additions & 1 deletion)

@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

+from typing import Union
+
 import torch

 from vllm.engine.arg_utils import EngineArgs
@@ -34,7 +36,7 @@ def test_gpu_memory_profiling():
     )

     # Set 10GiB as the total gpu ram to be device-agnostic
-    def mock_mem_info():
+    def mock_mem_info(device: Union[torch.types.Device, int] = None):
         current_usage = torch.cuda.memory_stats(
         )["allocated_bytes.all.current"]
         mock_total_bytes = 10 * 1024**3
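The extra `device` parameter on `mock_mem_info` is needed because the call now goes through the platform layer, which may forward an optional device argument; a mock with the old zero-argument signature would raise a TypeError. Below is a hedged sketch of how such a mock is typically wired in; the patch target is illustrative and not copied from the test file.

# Hedged sketch: the mock must accept the optional `device` argument that the
# platform layer may forward. The patch target below is an assumption.
from typing import Tuple, Union
from unittest.mock import patch

import torch


def mock_mem_info(device: Union[torch.types.Device, int] = None) -> Tuple[int, int]:
    current_usage = torch.cuda.memory_stats()["allocated_bytes.all.current"]
    mock_total_bytes = 10 * 1024**3  # pretend the device has 10 GiB in total
    return mock_total_bytes - current_usage, mock_total_bytes


# Example usage (assumed target; the real test may patch a different attribute):
# with patch("torch.cuda.mem_get_info", mock_mem_info):
#     ...  # run the profiling code under a fixed 10 GiB budget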
vllm/platforms/cuda.py (32 changes: 30 additions & 2 deletions)

@@ -5,8 +5,8 @@

 import os
 from functools import lru_cache, wraps
-from typing import (TYPE_CHECKING, Callable, List, Optional, Tuple, TypeVar,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
+                    TypeVar, Union)

 import torch
 from typing_extensions import ParamSpec
@@ -142,6 +142,34 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = 16

+    @classmethod
+    def empty_cache(cls) -> None:
+        torch.cuda.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls,
+                                device: Union[torch.types.Device,
+                                              int] = None) -> None:
+        torch.cuda.reset_peak_memory_stats(device)
+
+    @classmethod
+    def memory_stats(
+            cls,
+            device: Union[torch.types.Device, int] = None) -> Dict[str, Any]:
+        return torch.cuda.memory_stats(device)
+
+    @classmethod
+    def mem_get_info(
+            cls,
+            device: Union[torch.types.Device, int] = None) -> Tuple[int, int]:
+        return torch.cuda.mem_get_info(device)
+
+    @classmethod
+    def memory_reserved(cls,
+                        device: Union[torch.types.Device, int] = None) -> int:
+        """Return the memory reserved by the current device."""
+        return torch.cuda.memory_reserved(device)
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
vllm/platforms/interface.py (36 changes: 35 additions & 1 deletion)

@@ -4,7 +4,7 @@
 import platform
 import random
 from platform import uname
-from typing import TYPE_CHECKING, NamedTuple, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, NamedTuple, Optional, Tuple, Union

 import numpy as np
 import torch
@@ -306,6 +306,40 @@ def is_pin_memory_available(cls) -> bool:
             return False
         return True

+    @classmethod
+    def empty_cache(cls) -> None:
+        """
+        Clear the cache of the current device.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def reset_peak_memory_stats(cls,
+                                device: Union[torch.types.Device,
+                                              int] = None) -> None:
+        """Reset the peak memory stats of the current device."""
+        raise NotImplementedError
+
+    @classmethod
+    def memory_stats(
+            cls,
+            device: Union[torch.types.Device, int] = None) -> Dict[str, Any]:
+        """Return the memory stats of the current device."""
+        raise NotImplementedError
+
+    @classmethod
+    def mem_get_info(
+            cls,
+            device: Union[torch.types.Device, int] = None) -> Tuple[int, int]:
+        """Return the global free and total memory of the current device."""
+        raise NotImplementedError
+
+    @classmethod
+    def memory_reserved(cls,
+                        device: Union[torch.types.Device, int] = None) -> int:
+        """Return the memory reserved by the current device."""
+        raise NotImplementedError
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
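Because the base-class hooks only raise NotImplementedError, every backend has to opt in explicitly. Below is a hedged sketch of what an out-of-tree backend could provide; the `FooPlatform` name and the return values are invented for illustration, and only the signatures mirror the interface above. The remaining hooks (`reset_peak_memory_stats`, `memory_reserved`) would be filled in the same way.

# Hedged sketch of a hypothetical backend implementing the new memory hooks.
# "FooPlatform" is made up; only the signatures follow the Platform interface.
from typing import Any, Dict, Tuple, Union

import torch

from vllm.platforms.interface import Platform


class FooPlatform(Platform):

    @classmethod
    def empty_cache(cls) -> None:
        # A backend without a caching allocator can make this a no-op.
        pass

    @classmethod
    def mem_get_info(
            cls,
            device: Union[torch.types.Device, int] = None) -> Tuple[int, int]:
        # Must return (free_bytes, total_bytes), mirroring the interface docstring.
        total = 16 * 1024**3
        return total, total

    @classmethod
    def memory_stats(
            cls,
            device: Union[torch.types.Device, int] = None) -> Dict[str, Any]:
        # Report a zero peak so MemorySnapshot.torch_peak stays well-defined.
        return {"allocated_bytes.all.peak": 0}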
vllm/platforms/rocm.py (30 changes: 29 additions & 1 deletion)

@@ -2,7 +2,7 @@

 import os
 from functools import lru_cache, wraps
-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch

@@ -220,6 +220,34 @@ def verify_quantization(cls, quant: str) -> None:
     def get_punica_wrapper(cls) -> str:
         return "vllm.lora.punica_wrapper.punica_gpu.PunicaWrapperGPU"

+    @classmethod
+    def empty_cache(cls) -> None:
+        torch.cuda.empty_cache()
+
+    @classmethod
+    def reset_peak_memory_stats(cls,
+                                device: Union[torch.types.Device,
+                                              int] = None) -> None:
+        torch.cuda.reset_peak_memory_stats(device)
+
+    @classmethod
+    def memory_stats(
+            cls,
+            device: Union[torch.types.Device, int] = None) -> Dict[str, Any]:
+        return torch.cuda.memory_stats(device)
+
+    @classmethod
+    def mem_get_info(
+            cls,
+            device: Union[torch.types.Device, int] = None) -> Tuple[int, int]:
+        return torch.cuda.mem_get_info(device)
+
+    @classmethod
+    def memory_reserved(cls,
+                        device: Union[torch.types.Device, int] = None) -> int:
+        """Return the memory reserved by the current device."""
+        return torch.cuda.memory_reserved(device)
+
     @classmethod
     def get_current_memory_usage(cls,
                                  device: Optional[torch.types.Device] = None
vllm/utils.py (22 changes: 12 additions & 10 deletions)

@@ -1907,7 +1907,7 @@ def kill_process_tree(pid: int):
 class MemorySnapshot:
     """Memory snapshot."""
     torch_peak: int = 0
-    cuda_memory: int = 0
+    device_memory: int = 0
     torch_memory: int = 0
     non_torch_memory: int = 0
     timestamp: float = 0.0
@@ -1923,24 +1923,25 @@ def measure(self):
         # After `torch.cuda.reset_peak_memory_stats()`,
         # `torch.cuda.memory_reserved()` will keep growing, and only shrink
         # when we call `torch.cuda.empty_cache()` or OOM happens.
-        self.torch_peak = torch.cuda.memory_stats().get(
+        from vllm.platforms import current_platform
+        self.torch_peak = current_platform.memory_stats().get(
             "allocated_bytes.all.peak", 0)

-        self.cuda_memory = torch.cuda.mem_get_info(
-        )[1] - torch.cuda.mem_get_info()[0]
+        self.device_memory = current_platform.mem_get_info(
+        )[1] - current_platform.mem_get_info()[0]

         # torch.cuda.memory_reserved() is how many bytes
         # PyTorch gets from cuda (by calling cudaMalloc, etc.)
         # this is used to measure the non-torch memory usage
-        self.torch_memory = torch.cuda.memory_reserved()
+        self.torch_memory = current_platform.memory_reserved()

-        self.non_torch_memory = self.cuda_memory - self.torch_memory
+        self.non_torch_memory = self.device_memory - self.torch_memory
         self.timestamp = time.time()

     def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
         return MemorySnapshot(
             torch_peak=self.torch_peak - other.torch_peak,
-            cuda_memory=self.cuda_memory - other.cuda_memory,
+            device_memory=self.device_memory - other.device_memory,
             torch_memory=self.torch_memory - other.torch_memory,
             non_torch_memory=self.non_torch_memory - other.non_torch_memory,
             timestamp=self.timestamp - other.timestamp,
@@ -2012,9 +2013,10 @@ def memory_profiling(

     The increase of `non_torch_memory` from creating the current vLLM instance until after profiling to get (c.).
     """ # noqa
+    from vllm.platforms import current_platform
     gc.collect()
-    torch.cuda.empty_cache()
-    torch.cuda.reset_peak_memory_stats()
+    current_platform.empty_cache()
+    current_platform.reset_peak_memory_stats()

     result = MemoryProfilingResult()

@@ -2027,7 +2029,7 @@ def memory_profiling(
     yield result

     gc.collect()
-    torch.cuda.empty_cache()
+    current_platform.empty_cache()

     result.after_profile.measure()

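The renamed `device_memory` field feeds the same arithmetic as before: total minus free from `mem_get_info` gives everything in use on the device, `memory_reserved` gives the bytes owned by PyTorch's caching allocator, and the difference is attributed to non-torch consumers. A small worked example with made-up numbers:

# Worked example of the snapshot arithmetic above (all numbers are invented).
free, total = 50 * 1024**3, 80 * 1024**3          # as returned by mem_get_info()
device_memory = total - free                      # 30 GiB currently in use on the device
torch_memory = 26 * 1024**3                       # bytes reserved by PyTorch's allocator
non_torch_memory = device_memory - torch_memory   # 4 GiB held outside PyTorch

# Subtracting a baseline MemorySnapshot from a later one (the __sub__ above)
# gives the increase in each of these quantities during the profiled region.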
vllm/worker/worker.py (21 changes: 13 additions & 8 deletions)

Review comment (Member): this file is only used by nvidia gpus. we don't need to hide torch.cuda in this file.

@@ -214,12 +214,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         You may limit the usage of GPU memory
         by adjusting the `gpu_memory_utilization` parameter.
         """
+        from vllm.platforms import current_platform
+
         # Profile the memory usage of the model and get the maximum number of
         # cache blocks that can be allocated with the remaining free memory.
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        current_platform.empty_cache()
+        current_platform.reset_peak_memory_stats()

-        free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
+        free_memory_pre_profile, total_gpu_memory = (
+            current_platform.mem_get_info())

         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
@@ -271,14 +274,16 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
         return num_gpu_blocks, num_cpu_blocks

     def _assert_memory_footprint_increased_during_profiling(self):
+        from vllm.platforms import current_platform
+
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        free_gpu_memory, total = torch.cuda.mem_get_info()
-        cuda_memory = total - free_gpu_memory
-        assert self.baseline_snapshot.cuda_memory < cuda_memory, (
+        free_gpu_memory, total = current_platform.mem_get_info()
+        device_memory = total - free_gpu_memory
+        assert self.baseline_snapshot.device_memory < device_memory, (
             "Error in memory profiling. "
-            f"Initial used memory {self.baseline_snapshot.cuda_memory}, "
-            f"currently used memory {cuda_memory}. "
+            f"Initial used memory {self.baseline_snapshot.device_memory}, "
+            f"currently used memory {device_memory}. "
             f"This happens when the GPU memory was "
             "not properly cleaned up before initializing the vLLM instance.")