Pipeline Parallel
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
andoorve committed Apr 27, 2024
1 parent eefeb16 commit 06609d9
Showing 24 changed files with 450 additions and 172 deletions.
3 changes: 0 additions & 3 deletions vllm/config.py
@@ -520,9 +520,6 @@ def __init__(
self._verify_args()

def _verify_args(self) -> None:
if self.pipeline_parallel_size > 1:
raise NotImplementedError(
"Pipeline parallelism is not supported yet.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = True
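Removing this check means a ParallelConfig with pipeline_parallel_size > 1 now passes _verify_args. A minimal, hedged usage sketch follows; it assumes the LLM entrypoint forwards pipeline_parallel_size through EngineArgs, the model name and sizes are placeholders, and end-to-end pipeline-parallel serving may still be incomplete at this commit.

# Hedged sketch: only illustrates that the config-level guard no longer
# rejects pipeline_parallel_size > 1. Model and parallel sizes are
# illustrative; this may not run end-to-end at this commit.
from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",   # placeholder model
    tensor_parallel_size=1,
    pipeline_parallel_size=2,    # previously raised NotImplementedError
)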
3 changes: 3 additions & 0 deletions vllm/core/block_manager_v1.py
@@ -420,6 +420,9 @@ def append_slots(
def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
# NOTE: fork does not allocate a new physical block.
# Thus, it is always safe from OOM.
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.copy()
# When using a sliding window, blocks will be eventually reused.
3 changes: 3 additions & 0 deletions vllm/core/block_manager_v2.py
@@ -224,6 +224,9 @@ def get_common_computed_block_ids(
seq_block_ids)

def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
if parent_seq.seq_id not in self.block_tables:
# Parent sequence has either been freed or never existed.
return
src_block_table = self.block_tables[parent_seq.seq_id]
self.block_tables[child_seq.seq_id] = src_block_table.fork()

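Both block managers now tolerate forking from a parent sequence whose block table has already been freed. A small, self-contained sketch of that guard pattern follows; the ToyBlockManager and its dict-based block tables are illustrative stand-ins, not vLLM's real classes.

# Minimal sketch of the fork guard added above: copy-on-fork block tables
# keyed by sequence id, with an early return when the parent has already
# been freed. Names are illustrative only.
from typing import Dict, List


class ToyBlockManager:
    def __init__(self) -> None:
        self.block_tables: Dict[int, List[int]] = {}

    def allocate(self, seq_id: int, blocks: List[int]) -> None:
        self.block_tables[seq_id] = blocks

    def free(self, seq_id: int) -> None:
        self.block_tables.pop(seq_id, None)

    def fork(self, parent_id: int, child_id: int) -> None:
        if parent_id not in self.block_tables:
            # Parent was freed (or never allocated): nothing to copy.
            return
        # Fork copies the mapping only; no new physical blocks are needed.
        self.block_tables[child_id] = list(self.block_tables[parent_id])


manager = ToyBlockManager()
manager.allocate(0, [10, 11])
manager.free(0)
manager.fork(0, 1)          # silently skipped instead of raising KeyError
assert 1 not in manager.block_tables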
11 changes: 8 additions & 3 deletions vllm/core/scheduler.py
@@ -4,7 +4,8 @@
from dataclasses import dataclass, field
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import (CacheConfig, LoRAConfig, ParallelConfig,
SchedulerConfig)
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.core.policy import Policy, PolicyFactory
from vllm.logger import init_logger
@@ -241,10 +242,12 @@ def __init__(
self,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.parallel_config = parallel_config
# Note for LoRA scheduling: the current policy is extremely
# simple and NOT fair. It can lead to starvation of some
# LoRAs. This should be improved in the future.
@@ -264,8 +267,10 @@ def __init__(
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
block_size=self.cache_config.block_size,
num_gpu_blocks=self.cache_config.num_gpu_blocks,
num_cpu_blocks=self.cache_config.num_cpu_blocks,
num_gpu_blocks=self.cache_config.num_gpu_blocks //
self.parallel_config.pipeline_parallel_size,
num_cpu_blocks=self.cache_config.num_cpu_blocks //
self.parallel_config.pipeline_parallel_size,
sliding_window=self.cache_config.sliding_window,
enable_caching=self.cache_config.enable_prefix_caching)

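With pipeline parallelism there is one Scheduler per pipeline stage, and the profiled KV-cache block counts are split evenly across stages by integer division, as above. A quick hedged sketch of the arithmetic (the block counts and the dict-based "schedulers" are made up for illustration):

# Hedged sketch of the block partitioning above: each virtual engine
# (one Scheduler per pipeline stage) schedules against its share of the
# profiled KV-cache blocks. Numbers are illustrative.
num_gpu_blocks = 8192
num_cpu_blocks = 2048
pipeline_parallel_size = 2

gpu_blocks_per_scheduler = num_gpu_blocks // pipeline_parallel_size  # 4096
cpu_blocks_per_scheduler = num_cpu_blocks // pipeline_parallel_size  # 1024

schedulers = [
    {"gpu_blocks": gpu_blocks_per_scheduler,
     "cpu_blocks": cpu_blocks_per_scheduler}
    for _ in range(pipeline_parallel_size)
]

# The per-stage shares never exceed the profiled totals.
assert sum(s["gpu_blocks"] for s in schedulers) <= num_gpu_blocks
assert sum(s["cpu_blocks"] for s in schedulers) <= num_cpu_blocks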
33 changes: 29 additions & 4 deletions vllm/distributed/communication_op.py
@@ -2,10 +2,13 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed
from torch.distributed import ProcessGroup

from .parallel_state import (get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
from .parallel_state import (get_pipeline_model_parallel_group,
get_pipeline_model_parallel_next_rank,
get_pipeline_model_parallel_prev_rank,
get_tensor_model_parallel_group,
get_tensor_model_parallel_world_size,
is_pynccl_enabled_for_all_reduce)

@@ -87,7 +90,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
if torch.distributed.get_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
@@ -96,7 +99,7 @@ def tensor_model_parallel_gather(input_: torch.Tensor,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
if torch.distributed.get_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
@@ -209,3 +212,25 @@ def broadcast_tensor_dict(
for async_handle in async_handles:
async_handle.wait()
return tensor_dict


def send_next_rank(tensors: List[torch.Tensor]) -> None:
"""Send the tensors to the next pipeline model parallel rank."""
    combined_tensor = torch.cat(tensors, dim=0)
torch.distributed.send(combined_tensor,
get_pipeline_model_parallel_next_rank(),
get_pipeline_model_parallel_group())


def recv_prev_rank(num_tensors: int, sizes: torch.Size, dtype: torch.dtype,
                   device: torch.device) -> List[torch.Tensor]:
    """Receive tensors from the previous pipeline model parallel rank."""
    sizes = list(sizes)
combined_tensor = torch.empty([sizes[0] * num_tensors] + sizes[1:],
dtype=dtype,
device=device)
torch.distributed.recv(combined_tensor,
get_pipeline_model_parallel_prev_rank(),
get_pipeline_model_parallel_group())
return torch.chunk(combined_tensor, num_tensors, dim=0)
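send_next_rank packs a list of tensors into one contiguous buffer with torch.cat before a single point-to-point send, and recv_prev_rank undoes it with torch.chunk. A hedged, single-process sketch of that pack/unpack symmetry follows; no process group is created, and the actual send/recv calls are replaced by a plain copy so the sketch runs anywhere.

# Hedged sketch of the pack/unpack convention used by send_next_rank /
# recv_prev_rank: concatenate along dim 0 on the sender, receive into a
# buffer of shape [num_tensors * size0, ...], then torch.chunk it back.
# torch.distributed.send/recv are stubbed out with a copy.
import torch

hidden_states = torch.randn(4, 8)
residual = torch.randn(4, 8)
tensors = [hidden_states, residual]

# Sender side: one contiguous buffer for a single send.
combined = torch.cat(tensors, dim=0)                 # shape [8, 8]

# Receiver side: preallocate the same shape, then "receive" into it.
sizes = list(tensors[0].size())
recv_buf = torch.empty([sizes[0] * len(tensors)] + sizes[1:],
                       dtype=combined.dtype)
recv_buf.copy_(combined)                             # stands in for recv()

# Unpack into the original tensors.
out = torch.chunk(recv_buf, len(tensors), dim=0)
assert torch.equal(out[0], hidden_states)
assert torch.equal(out[1], residual)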
17 changes: 17 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -275,6 +275,23 @@ def get_pipeline_model_parallel_prev_rank():
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def is_pipeline_model_parallel_first_rank() -> bool:
"""Return True if the caller is the first rank in the pipeline"""
return get_pipeline_model_parallel_rank() == 0


def is_pipeline_model_parallel_last_rank() -> bool:
"""Return True if the caller is the last rank in the pipeline"""
return get_pipeline_model_parallel_rank(
) == get_pipeline_model_parallel_world_size() - 1


def is_tensor_model_parallel_first_rank() -> bool:
"""Return True if the caller is the first rank in the tensor
parallel group"""
return get_tensor_model_parallel_rank() == 0


def destroy_model_parallel():
"""Set the groups to none and destroy them."""
global _TENSOR_MODEL_PARALLEL_GROUP
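The new helpers let callers branch on their position in the pipeline. A hedged, standalone sketch of how a worker typically uses such checks follows; plain integers stand in for the real process-group queries, and the stage descriptions are the usual pipeline-parallel division of labor rather than vLLM-specific behavior.

# Hedged sketch: how first/last-rank checks commonly gate a pipeline
# stage's work. Real code would call the vLLM helpers above; here the
# rank and world size are plain variables so the sketch runs anywhere.
def run_stage(pp_rank: int, pp_world_size: int) -> str:
    is_first = pp_rank == 0
    is_last = pp_rank == pp_world_size - 1

    if is_first:
        action = "embed input tokens, then send activations downstream"
    elif is_last:
        action = "receive activations, compute logits and sample"
    else:
        action = "receive activations, run middle layers, send onward"
    return f"rank {pp_rank}/{pp_world_size}: {action}"


for rank in range(4):
    print(run_stage(rank, 4))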
46 changes: 36 additions & 10 deletions vllm/engine/async_llm_engine.py
@@ -196,7 +196,7 @@ def has_new_requests(self):
class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""

async def step_async(self) -> List[RequestOutput]:
async def step_async(self, virtual_engine: int) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results.
The workers are run asynchronously if possible.
@@ -206,14 +206,15 @@ async def step_async(self) -> List[RequestOutput]:
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()

if not scheduler_outputs.is_empty():
# Execute the model.
output = await self.model_executor.execute_model_async(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
scheduler_outputs.blocks_to_swap_out,
scheduler_outputs.blocks_to_copy)
scheduler_outputs.blocks_to_copy, virtual_engine)
else:
output = []

@@ -428,15 +429,16 @@ def _init_engine(self, *args,
# order of the arguments.
cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1:
if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization
else:
num_gpus = 1
engine_class = ray.remote(num_gpus=num_gpus)(
self._engine_class).remote
return engine_class(*args, **kwargs)

async def engine_step(self) -> bool:
async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
@@ -467,7 +469,7 @@ async def engine_step(self) -> bool:
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
request_outputs = await self.engine.step_async(virtual_engine)

# Put the outputs into the corresponding streams.
for request_output in request_outputs:
@@ -483,18 +485,42 @@ async def _engine_abort(self, request_ids: Iterable[str]):
self.engine.abort_request(request_ids)

async def run_engine_loop(self):
has_requests_in_progress = False
has_requests_in_progress = [
False
] * self.engine.parallel_config.pipeline_parallel_size
while True:
if not has_requests_in_progress:
if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...")
await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(
asyncio.wait_for(self.engine_step(ve),
ENGINE_ITERATION_TIMEOUT_S)) for ve in
range(self.engine.parallel_config.pipeline_parallel_size)
]
has_requests_in_progress = [
True
] * self.engine.parallel_config.pipeline_parallel_size

# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
has_requests_in_progress = await asyncio.wait_for(
self.engine_step(), ENGINE_ITERATION_TIMEOUT_S)
done, _ = await asyncio.wait(
requests_in_progress, return_when=asyncio.FIRST_COMPLETED)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
if result or self.engine.scheduler[
virtual_engine].has_unfinished_seqs():
requests_in_progress[
virtual_engine] = asyncio.create_task(
asyncio.wait_for(
self.engine_step(virtual_engine),
ENGINE_ITERATION_TIMEOUT_S))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
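run_engine_loop now drives one engine_step task per virtual engine (one per pipeline stage) and re-arms whichever task finishes first. A hedged, self-contained asyncio sketch of that wait-and-re-arm pattern follows; engine_step here is a stand-in that just sleeps, and the timeout and scheduler checks of the real loop are reduced to a comment.

# Hedged sketch of the FIRST_COMPLETED re-arm loop above: one task per
# virtual engine; whichever finishes is restarted. The engine_step
# coroutine is a stand-in, not vLLM's.
import asyncio
import random


async def engine_step(virtual_engine: int) -> bool:
    await asyncio.sleep(random.uniform(0.01, 0.05))  # pretend to decode
    return random.random() < 0.5                     # True: more work left


async def run_engine_loop(pipeline_parallel_size: int, iterations: int) -> None:
    tasks = [
        asyncio.create_task(engine_step(ve))
        for ve in range(pipeline_parallel_size)
    ]
    for _ in range(iterations):
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            ve = tasks.index(task)
            has_more_work = task.result()
            # Re-arm the finished virtual engine (the real loop also checks
            # scheduler[ve].has_unfinished_seqs() and wraps the step in a
            # timeout before re-arming).
            tasks[ve] = asyncio.create_task(engine_step(ve))
            print(f"virtual engine {ve} finished, more work: {has_more_work}")

    # Cancel whatever is still in flight before exiting the sketch.
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)


asyncio.run(run_engine_loop(pipeline_parallel_size=2, iterations=5))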
49 changes: 35 additions & 14 deletions vllm/engine/llm_engine.py
@@ -101,7 +101,8 @@ def __init__(
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, "
"max_seq_len=%d, download_dir=%r, load_format=%s, "
"tensor_parallel_size=%d, disable_custom_all_reduce=%s"
"tensor_parallel_size=%d, pipeline_parallel_size=%d, "
"disable_custom_all_reduce=%s, "
"quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d)",
@@ -119,6 +120,7 @@ def __init__(
load_config.download_dir,
load_config.load_format,
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size,
parallel_config.disable_custom_all_reduce,
model_config.quantization,
model_config.enforce_eager,
@@ -211,7 +213,11 @@ def __init__(
# Create the scheduler.
# NOTE: the cache_config here have been updated with the numbers of
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = Scheduler(scheduler_config, cache_config, lora_config)
self.scheduler = [
Scheduler(scheduler_config, cache_config, parallel_config,
lora_config)
for _ in range(parallel_config.pipeline_parallel_size)
]

# Metric Logging.
if self.log_stats:
@@ -442,7 +448,12 @@ def add_request(
arrival_time, lora_request, multi_modal_data)

# Add the sequence group to the scheduler.
self.scheduler.add_seq_group(seq_group)
costs = [
scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler
]
min_cost_scheduler = self.scheduler[costs.index(min(costs))]
min_cost_scheduler.add_seq_group(seq_group)

def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
"""Aborts a request(s) with the given ID.
@@ -461,19 +472,22 @@ def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
>>> # abort the request
>>> engine.abort_request(request_id)
"""
self.scheduler.abort_seq_group(request_id)
for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id)

def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return self.scheduler.get_num_unfinished_seq_groups()
return sum(scheduler.get_num_unfinished_seq_groups()
for scheduler in self.scheduler)

def has_unfinished_requests(self) -> bool:
"""Returns True if there are unfinished requests."""
return self.scheduler.has_unfinished_seqs()
return any(scheduler.has_unfinished_seqs()
for scheduler in self.scheduler)

def _process_model_outputs(
self,
@@ -507,7 +521,8 @@ def _process_model_outputs(
self.output_processor.process_outputs(seq_group, outputs)

# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()
for scheduler in self.scheduler:
scheduler.free_finished_seq_groups()

# Create the outputs.
request_outputs: List[RequestOutput] = []
@@ -572,7 +587,8 @@ def step(self) -> List[RequestOutput]:
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
seq_group_metadata_list, scheduler_outputs = self.scheduler[
0].schedule()

if not scheduler_outputs.is_empty():
output = self.model_executor.execute_model(
@@ -616,20 +632,25 @@ def _get_stats(

# KV Cache Usage in %.
num_total_gpu = self.cache_config.num_gpu_blocks
num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks()
num_free_gpu = sum(scheduler.block_manager.get_num_free_gpu_blocks()
for scheduler in self.scheduler)
gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu)

num_total_cpu = self.cache_config.num_cpu_blocks
cpu_cache_usage = 0.
if num_total_cpu > 0:
num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks(
)
num_free_cpu = sum(
scheduler.block_manager.get_num_free_cpu_blocks()
for scheduler in self.scheduler)
cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu)

# Scheduler State
num_running = len(self.scheduler.running)
num_swapped = len(self.scheduler.swapped)
num_waiting = len(self.scheduler.waiting)
num_running = sum(
len(scheduler.running) for scheduler in self.scheduler)
num_swapped = sum(
len(scheduler.swapped) for scheduler in self.scheduler)
num_waiting = sum(
len(scheduler.waiting) for scheduler in self.scheduler)

# Iteration stats if we have scheduler output.
num_prompt_tokens = 0
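add_request now places each new sequence group on the least-loaded scheduler, measured by its count of unfinished sequence groups, while aborts, frees, and stats are fanned out across all schedulers. A hedged standalone sketch of the placement rule follows; the ToyScheduler is an invented stand-in for vLLM's Scheduler.

# Hedged sketch of the least-loaded placement used in add_request above:
# pick the scheduler (virtual engine) with the fewest unfinished
# sequence groups.
from typing import List


class ToyScheduler:
    def __init__(self) -> None:
        self.unfinished: List[str] = []

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.unfinished)

    def add_seq_group(self, request_id: str) -> None:
        self.unfinished.append(request_id)


def add_request(schedulers: List[ToyScheduler], request_id: str) -> int:
    costs = [s.get_num_unfinished_seq_groups() for s in schedulers]
    target = costs.index(min(costs))          # least-loaded virtual engine
    schedulers[target].add_seq_group(request_id)
    return target


schedulers = [ToyScheduler(), ToyScheduler()]
for i in range(5):
    ve = add_request(schedulers, f"req-{i}")
    print(f"req-{i} -> virtual engine {ve}")
# Requests alternate between the two schedulers as their loads equalize.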
3 changes: 2 additions & 1 deletion vllm/engine/output_processor/multi_step.py
@@ -133,4 +133,5 @@ def _process_seq_outputs(self, seq: Sequence,
break

if seq.is_finished():
self.scheduler.free_seq(seq)
for scheduler in self.scheduler:
scheduler.free_seq(seq)
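This mirrors the pattern used for abort_request and free_finished_seq_groups: with one scheduler per pipeline stage, request-level operations are broadcast to every scheduler, and only the one that owns the sequence does anything. A hedged toy sketch of that broadcast-and-no-op idea (the ToyScheduler is invented for illustration):

# Hedged sketch: request-level operations are applied to every per-stage
# scheduler; only the owner actually changes state, the rest no-op.
class ToyScheduler:
    def __init__(self) -> None:
        self.seq_groups = set()

    def add(self, request_id: str) -> None:
        self.seq_groups.add(request_id)

    def abort(self, request_id: str) -> None:
        self.seq_groups.discard(request_id)   # no-op if not owned here


schedulers = [ToyScheduler() for _ in range(2)]
schedulers[1].add("req-42")                    # owned by stage 1 only

for scheduler in schedulers:                   # broadcast the abort
    scheduler.abort("req-42")

assert all("req-42" not in s.seq_groups for s in schedulers)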