[RFC]: Device-agnostic Abstraction for V1 Architecture #12992

Closed

liangfu opened this issue Feb 9, 2025 · 0 comments

liangfu commented Feb 9, 2025

Motivation.

The vLLM V1 Engine architecture (#8779) treats chunked-prefill and prefix-caching as first-class features, and simplifies multi-step scheduling via asynchronous processing. Efforts to extend device support (e.g. #12480) have surfaced challenges in reusing the existing code structure.

Purpose

This RFC intends to discuss how to extend device support in the V1 architecture, under the explicit assumption that the device backend supports both chunked-prefill and prefix caching.

Goal:

  • Safely assume the device backend supports both chunked-prefill and prefix-caching
  • Simplify the device-agnostic design and encourage code reuse
    • Reuse most of the GPUModelRunner class
    • Encourage code reuse among devices that support the openxla backend in torch.compile
    • Abstract device-incompatible code into helper functions, instead of adding if-else conditions for device-specific optimizations (or bugs); see the sketch after this list.
  • Keep the attention backend structure similar to V0
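
To make the helper-function sub-goal concrete, the following is a minimal sketch of the intended pattern; the hook name _apply_device_workarounds() is hypothetical and only illustrates how device-specific code would be isolated behind an overridable helper instead of if-else branches in the shared path.

class ModelRunnerBase:

    def _prepare_inputs(self, scheduler_output):
        # ... device-agnostic input processing shared by all backends ...
        self._apply_device_workarounds()  # hypothetical hook, no-op by default

    def _apply_device_workarounds(self) -> None:
        """Backends override this only when a device-specific fix is needed."""
        pass


class GPUModelRunner(ModelRunnerBase):

    def _apply_device_workarounds(self) -> None:
        # A GPU-only optimization or bug workaround is isolated here rather
        # than guarded by device checks inside the shared code path.
        ...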

Non-goal:

Proposed Change.

Move most parts of GPUModelRunner into ModelRunnerBase. This includes:

  • Most of the constructor __init__() is device-agnostic, except for caching device properties. We plan to introduce a _cache_device_properties() function to abstract this part for each device backend (a condensed sketch of the resulting split follows this list).
  • _update_states(), _calc_mrope_positions(), _prepare_sampling(), _execute_encoder(), _gather_encoder_outputs(), get_model(), _dummy_run(), initialize_kv_cache(), and get_kv_cache_spec() should be reused entirely without change.
  • _prepare_inputs() is a bit tricky, since FlashAttentionBackend and FlashAttentionMetadata are used directly. As long as we maintain a consistent interface among the flash-attention backends, no change should be needed. The challenge is that changing the interface for GPU might require changing every other backend as well. Even so, this is the most important function to reuse across device backends, because it centralizes the input data processing needed to support combinations of features (e.g. chunked-prefill, prefix-caching, speculative decoding, and more in the future).
  • execute_model() pads the input for piecewise CUDA graphs when they are enabled. Since we can always fall back to eager execution when piecewise graphs are disabled, the use_piecewise_graph flag would default to False in the base class.
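
As a reference point, here is a condensed sketch of how the resulting base class could look; the method names are the ones listed above, while the constructor body and the VllmConfig/device parameters are illustrative assumptions rather than the final interface.

import torch

from vllm.config import VllmConfig


class ModelRunnerBase:

    def __init__(self, vllm_config: VllmConfig, device: torch.device):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.device = device
        # ... the remaining device-agnostic setup (request state, input
        # batching, sampling metadata, KV-cache bookkeeping, ...) stays
        # here unchanged ...

        # The only device-specific part of the constructor is factored out;
        # its base-class default is sketched further below.
        self._cache_device_properties()

    # _update_states(), _calc_mrope_positions(), _prepare_sampling(),
    # _execute_encoder(), _gather_encoder_outputs(), get_model(),
    # _dummy_run(), initialize_kv_cache(), get_kv_cache_spec(),
    # _prepare_inputs() and execute_model() move here verbatim from
    # GPUModelRunner.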

Keep GPU-specific optimizations in GPUModelRunner

  • Ideally, we only need to keep four functions for now: _cache_device_properties(), load_model(), profile_run(), and capture_model(). (See the Appendix for details.)
  • The _cache_device_properties() function is called from ModelRunnerBase.__init__(), so that accessing these properties adds no overhead at execution time.
  • load_model() uses DeviceMemoryProfiler, which is currently GPU-specific. As a short-term solution we plan to duplicate this function, and later make DeviceMemoryProfiler device-agnostic to remove the duplication (one possible shape is sketched after this list).
  • Although most of profile_run() is device-agnostic (except for torch.cuda.synchronize() when profiling with LoRA), profile_run() is only called from gpu_worker to determine the size of the available memory. We plan to keep it in the GPU-specific ModelRunner, since reusing the function does not provide much value and may slow down GPU development.
  • capture_model() uses graph_capture() to trigger CUDA graph capture for specific shapes. The graph_capture() function is currently device-specific.
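
As a rough illustration of the longer-term plan for DeviceMemoryProfiler, one possible device-agnostic shape is sketched below. This is not the current implementation; the CUDA branch is only one example of a per-device memory query, and the helper name _bytes_in_use() is hypothetical.

import gc

import torch


def _bytes_in_use(device: torch.device) -> int:
    """Return the bytes currently in use on `device` (illustrative dispatch)."""
    if device.type == "cuda":
        free_bytes, total_bytes = torch.cuda.mem_get_info(device)
        return total_bytes - free_bytes
    # Other backends (e.g. devices driven through the openxla backend) would
    # plug in their own memory query here.
    raise NotImplementedError(f"Unsupported device type: {device.type}")


class DeviceMemoryProfiler:
    """Context manager that reports memory consumed inside the `with` block."""

    def __init__(self, device: torch.device):
        self.device = device

    def __enter__(self):
        self.initial_memory = _bytes_in_use(self.device)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        gc.collect()
        self.consumed_memory = _bytes_in_use(self.device) - self.initial_memory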

Repurpose use_cuda_graph (in ModelRunnerBase) → use_piecewise_graph

  • Rename the cudagraph_batch_sizes attribute to tokenwise_graph_batch_sizes, since the CUDA graph is captured for token-wise operators.
  • capture_model() would be left unimplemented in the base class, since it is mostly device-specific (see the sketch below).
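
A minimal sketch of the proposed base-class defaults follows. The _get_padded_num_tokens() helper is hypothetical; it only shows how the shared execute_model() path could consume the repurposed flag and fall back to eager execution.

class ModelRunnerBase:

    def _cache_device_properties(self) -> None:
        # Conservative defaults: backends that support piecewise graphs
        # (e.g. the GPUModelRunner in the Appendix) override this method.
        self.use_piecewise_graph = False
        self.tokenwise_graph_batch_sizes: list[int] = []

    def _get_padded_num_tokens(self, num_tokens: int) -> int:
        """Pad to the nearest captured token-wise graph size, else run eagerly."""
        if (not self.use_piecewise_graph
                or not self.tokenwise_graph_batch_sizes
                or num_tokens > self.tokenwise_graph_batch_sizes[-1]):
            return num_tokens
        # tokenwise_graph_batch_sizes is kept in ascending order.
        return next(size for size in self.tokenwise_graph_batch_sizes
                    if size >= num_tokens)

    def capture_model(self) -> None:
        # Graph capture is device-specific and left unimplemented here; see
        # the GPU override in the Appendix below.
        raise NotImplementedError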

Appendix

An example of a revised GPUModelRunner class that derives from the proposed ModelRunnerBase:

class GPUModelRunner(ModelRunnerBase):

    def _cache_device_properties(self):
        self.use_piecewise_graph = (
            self.vllm_config.compilation_config.level
            == CompilationLevel.PIECEWISE
            and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # The convention is different.
        # self.tokenwise_graph_batch_sizes sorts in ascending order.
        # The batch sizes in the config are in descending order.
        self.tokenwise_graph_batch_sizes = list(
            reversed(
                self.vllm_config.compilation_config.cudagraph_capture_sizes))

        # Cache the device properties.
        self.device_properties = torch.cuda.get_device_properties(self.device)
        self.num_sms = self.device_properties.multi_processor_count

    def capture_model(self) -> None:
        if not self.use_piecewise_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture(device=self.device):
            for num_tokens in reversed(self.tokenwise_graph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
                    self._dummy_run(num_tokens)
                self._dummy_run(num_tokens)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def load_model(self) -> None:
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)
            if self.lora_config:
                self.model = self.load_lora_model(self.model,
                                                  self.model_config,
                                                  self.scheduler_config,
                                                  self.lora_config,
                                                  self.device)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
@liangfu liangfu added the RFC label Feb 9, 2025
@liangfu liangfu changed the title [RFC]: Device-agnostic Refactoring for V1 Architecture [RFC]: Device-agnostic Abstraction for V1 Architecture Feb 9, 2025
@liangfu liangfu closed this as not planned Feb 20, 2025