fix vLLM v0.4.1 (#283)
* fix vLLM v0.4.1

* assert vLLM

* fixed TP=1

* fix

* fix

* fix

* fix TP=1

* add kwargs["worker_use_ray"] = True
hijkzzz committed Apr 29, 2024
1 parent d2a3ec1 commit 9af18be
Showing 3 changed files with 39 additions and 65 deletions.
2 changes: 1 addition & 1 deletion dockerfile/Dockerfile
@@ -15,7 +15,7 @@ RUN DEBIAN_FRONTEND=noninteractive apt install -y tzdata

RUN apt-get -y install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev gdb
RUN pip uninstall xgboost transformer_engine flash_attn -y
RUN pip install vllm==0.3.2
RUN pip install vllm==0.4.1

COPY docker-entrypoint.sh .
RUN chmod a+x docker-entrypoint.sh
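A quick sanity check, not part of the commit: after building an image from this Dockerfile, the pinned wheel should satisfy the version gate that LLMRayActor asserts below. This assumes vLLM is importable inside the container.

# Assumed to run inside a container built from this Dockerfile.
import vllm

# Mirrors the string-comparison guard used in LLMRayActor below.
assert vllm.__version__ >= "0.4.1", f"expected vLLM >= 0.4.1, got {vllm.__version__}"
print(f"vLLM {vllm.__version__} is installed")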
57 changes: 37 additions & 20 deletions openrlhf/trainer/ray/vllm_engine.py
@@ -1,50 +1,67 @@
import os
from typing import List
from typing import Dict, List

import ray
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from openrlhf.utils.logging import init_logger

logger = init_logger(__name__)


@ray.remote
class LLMRayActor:
    def __init__(self, *args, **kwargs):
        import vllm

        if vllm.__version__ < "0.2.7" or kwargs["tensor_parallel_size"] == 1:
            from vllm.worker import worker
        assert vllm.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1"

        self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1

        # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py
        if self.use_gpu_executor:
            from openrlhf.trainer.ray.vllm_worker_wrap import WorkerWrap

            worker.Worker = WorkerWrap
            vllm.worker.worker.Worker = WorkerWrap
        else:
            # NOTE: In 0.2.7, vLLM made a major change to its architecture which move one worker into the driver process.
            # Driver process will manually set CUDA_VISIBLE_DEVICES before worker init. To avoid importing torch before
            # set CUDA_VISIBLE_DEVICES, we must defer monkey patch.
            # For more detail, see: https://github.com/vllm-project/vllm/pull/2221
            def _set_cuda_visible_devices(device_ids: List[int]):
                os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids))
            # RayGPUExecutor
            # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5
            kwargs["worker_use_ray"] = True

                from vllm.worker import worker
                from openrlhf.trainer.ray.vllm_worker_wrap import WorkerWrap
            if vllm.__version__ > "0.4.1":
                RayWorkerWrapperPath = vllm.executor.ray_utils
            else:
                RayWorkerWrapperPath = vllm.engine.ray_utils

                worker.Worker = WorkerWrap
            class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper):
                def __init__(self, *args, **kwargs) -> None:
                    kwargs["worker_module_name"] = "openrlhf.trainer.ray.vllm_worker_wrap"
                    kwargs["worker_class_name"] = "WorkerWrap"
                    super().__init__(*args, **kwargs)

            vllm.engine.llm_engine.set_cuda_visible_devices = _set_cuda_visible_devices
            RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper

        if vllm.__version__ >= "0.4.1":
            kwargs["worker_use_ray"] = True
        self.llm = vllm.LLM(*args, **kwargs)

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

    def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name):
        return self.llm.llm_engine._run_workers(
            "init_process_group", master_address, master_port, rank_offset, world_size, group_name
        )
        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.init_process_group(
                master_address, master_port, rank_offset, world_size, group_name
            )
        else:
            return self.llm.llm_engine.model_executor._run_workers(
                "init_process_group", master_address, master_port, rank_offset, world_size, group_name
            )

    def update_weight(self, name, dtype, shape, empty_cache=False):
        return self.llm.llm_engine._run_workers("update_weight", name, dtype, shape, empty_cache)
        if self.use_gpu_executor:
            return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache)
        else:
            return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache)


def create_vllm_engines(num_engines: int, tensor_parallel_size: int, pretrain: str, seed: int):
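A usage sketch, not part of the commit, showing how a trainer process might drive LLMRayActor after this change. Since create_vllm_engines is truncated above, the resource options, model name, and process-group arguments here are placeholders rather than the repo's actual values.

import ray

from openrlhf.trainer.ray.vllm_engine import LLMRayActor

ray.init()

# tensor_parallel_size == 1 takes the GPUExecutor path (Worker patched directly);
# tensor_parallel_size > 1 takes the RayGPUExecutor path with worker_use_ray=True.
engine = LLMRayActor.options(num_gpus=1).remote(
    "meta-llama/Llama-2-7b-hf",  # placeholder model, forwarded to vllm.LLM
    tensor_parallel_size=1,
    seed=42,
)

# Join the weight-update process group. The trainer must concurrently create the
# same group (as rank 0) on its side, otherwise this call blocks waiting for peers.
handle = engine.init_process_group.remote("127.0.0.1", 29500, 1, 2, "openrlhf")

# Generation is forwarded to vllm.LLM.generate.
outputs = ray.get(engine.generate.remote(["Hello, my name is"]))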
45 changes: 1 addition & 44 deletions openrlhf/trainer/ray/vllm_worker_wrap.py
@@ -2,7 +2,6 @@
import inspect

import torch
from vllm.model_executor.weight_utils import hf_model_weights_iterator
from vllm.worker.worker import Worker

from openrlhf.utils.distributed_util import init_process_group
@@ -11,46 +10,7 @@
logger = init_logger(__name__)


def _hf_model_weights_iterator_wrap(model_name_or_path, *args, **kwargs):
    if isinstance(model_name_or_path, dict):
        for name, param in model_name_or_path.items():
            yield name, param
    else:
        yield from hf_model_weights_iterator(model_name_or_path, *args, **kwargs)


class WorkerWrap(Worker):
    def __init__(self, *args, **kwargs):
        # Monkey patch hf_model_weights_iterator to allow update single weight
        import vllm

        self.vllm_version = vllm.__version__

        if vllm.__version__ < "0.2.5":
            from vllm.model_executor.weight_utils import hf_model_weights_iterator

            modules = inspect.getmembers(vllm.model_executor.models, inspect.ismodule)
            for _, m in modules:
                m.hf_model_weights_iterator = _hf_model_weights_iterator_wrap
        elif vllm.__version__ < "0.4.1":
            # NOTE: In 0.2.5, vLLM introduce lazy model loader
            # https://github.com/vllm-project/vllm/pull/2044
            from vllm.model_executor.models import _MODELS, ModelRegistry

            load_model_cls = ModelRegistry.load_model_cls

            def patched_load_model_cls(model_arch: str):
                module_name, _ = _MODELS[model_arch]
                module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
                module.hf_model_weights_iterator = _hf_model_weights_iterator_wrap
                logger.info(f"Monkey patch hf_model_weights_iterator for module {module_name}")

                return load_model_cls(model_arch)

            ModelRegistry.load_model_cls = patched_load_model_cls

        super().__init__(*args, **kwargs)

    def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name):
        """Init torch process group for model weights update"""
        assert torch.distributed.is_initialized(), f"default torch process group must be initialized"
@@ -79,10 +39,7 @@ def update_weight(self, name, dtype, shape, empty_cache=False):
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        torch.distributed.broadcast(weight, 0, group=self._model_update_group)

        if self.vllm_version < "0.4.1":
            self.model_runner.model.load_weights(model_name_or_path={name: weight})
        else:
            self.model_runner.model.load_weights(weights=[(name, weight)])
        self.model_runner.model.load_weights(weights=[(name, weight)])

        del weight
        # TODO: should we empty cache if all weights have updated?
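A sketch of the trainer-side counterpart, also not part of this commit: rank 0 of the model-update group announces each parameter via the engine's update_weight and then broadcasts the tensor, matching the torch.distributed.broadcast(weight, 0, group=...) call in WorkerWrap.update_weight. The engine handle and process group are assumed to have been set up as in init_process_group above.

import ray
import torch


def broadcast_param_to_vllm(engine, model_update_group, name: str, param: torch.Tensor):
    # Tell the vLLM workers which weight is coming; they allocate an empty buffer
    # of this dtype/shape and wait on the broadcast from rank 0.
    handle = engine.update_weight.remote(name, dtype=param.dtype, shape=param.shape, empty_cache=False)
    # Send the actual data from rank 0 of the shared model-update group.
    torch.distributed.broadcast(param.data, 0, group=model_update_group)
    ray.get(handle)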
