diff --git a/jax-inference-offloading/dockerfile/oss.dockerfile b/jax-inference-offloading/dockerfile/oss.dockerfile
index 2555e3798..dbbac60b0 100644
--- a/jax-inference-offloading/dockerfile/oss.dockerfile
+++ b/jax-inference-offloading/dockerfile/oss.dockerfile
@@ -91,11 +91,12 @@ FROM mealkit AS final
 # Finalize installation
 RUN <<"EOF" bash -ex -o pipefail
 export PIP_INDEX_URL=https://download.pytorch.org/whl/cu129
-export PIP_EXTRA_INDEX_URL=https://pypi.org/simple
+export PIP_EXTRA_INDEX_URL="https://flashinfer.ai/whl/cu129 https://pypi.org/simple"
 pushd /opt/pip-tools.d
 pip-compile -o requirements.txt $(ls requirements*.in) --constraint overrides.in
 # remove cuda wheels from install list since the container already has them
 sed -i 's/^nvidia-/# nvidia-/g' requirements.txt
+sed -i 's/# nvidia-nvshmem/nvidia-nvshmem/g' requirements.txt
 pip install --no-deps --src /opt -r requirements.txt
 # make pip happy about the missing torch dependencies
 pip-mark-installed nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 \
diff --git a/jax-inference-offloading/examples/requirements.in b/jax-inference-offloading/examples/requirements.in
index 06b18129e..9dba5c073 100644
--- a/jax-inference-offloading/examples/requirements.in
+++ b/jax-inference-offloading/examples/requirements.in
@@ -1,4 +1,4 @@
-google-tunix==0.1.3
+google-tunix
 datasets
 tensorflow-datasets
 tensorflow-cpu; platform_machine == "x86_64"
diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py
index 59252ca02..1fb71727d 100644
--- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py
+++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py
@@ -30,6 +30,7 @@
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
+    WEIGHT_LOADER_V2_SUPPORTED,
 )
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
@@ -51,13 +52,27 @@ def device_info(self):
     )
 
   def set_sharding(self):
-    for _, module in self.model_runner.model.named_modules():
+    # The vLLM V2 weight loader does not support loading pre-sharded weights
+    # for the parallel linear modules, so we force these modules to use the
+    # V1 weight loader. Once the V2 weight loader supports pre-sharded
+    # weights, this workaround can be removed.
+
+    # Prevent unquantized linear modules from using the V2 weight loader.
+    if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
+      WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
+      logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.")
+
+    for name, module in self.model_runner.model.named_modules():
       if type(module) in [
           RowParallelLinear,
           ColumnParallelLinear,
           MergedColumnParallelLinear,
           QKVParallelLinear,
       ]:
+        logger.debug(f"Setting sharding for module: {name} of type {type(module)}")
+        # Force use of the V1 weight_loader.
+        module.weight.weight_loader = module.weight_loader
+        # Instruct the V1 loader to treat the incoming weight as pre-sharded.
         module.weight.is_sharded_weight = True
 
   def get_tp_sharding_specs(self):
diff --git a/jax-inference-offloading/setup.py b/jax-inference-offloading/setup.py
index f60fe32e9..9bd9a89b8 100644
--- a/jax-inference-offloading/setup.py
+++ b/jax-inference-offloading/setup.py
@@ -66,7 +66,7 @@ def run(self):
         'jax==0.8.0',
         'jaxtyping',
         'kagglehub',
-        'vllm[flashinfer]==0.10.2',
+        'vllm==0.11.2',
     ],
     cmdclass={
         'build_protos': BuildPackageProtos,
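For context on the Dockerfile change: the first `sed` pass comments out every `nvidia-*` wheel that `pip-compile` pinned (the container already ships those CUDA libraries), and the new second pass re-enables only the nvshmem wheel, presumably because the base image does not bundle it. A minimal Python sketch replaying the two substitutions, with hypothetical package names and versions:

```python
# Illustrative only: replays the Dockerfile's two sed substitutions in Python.
# The pins below are hypothetical examples, not the real compiled requirements.
import re

requirements = """\
nvidia-cublas-cu12==12.9.1
nvidia-nvshmem-cu12==3.3.9
torch==2.8.0
"""

# Pass 1: comment out every CUDA wheel; the container already provides them.
requirements = re.sub(r"^nvidia-", "# nvidia-", requirements, flags=re.M)
# Pass 2: un-comment nvshmem so it is installed from pip after all.
requirements = re.sub(r"^# nvidia-nvshmem", "nvidia-nvshmem", requirements, flags=re.M)

print(requirements)
# # nvidia-cublas-cu12==12.9.1
# nvidia-nvshmem-cu12==3.3.9
# torch==2.8.0
```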
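For reviewers unfamiliar with vLLM's loader dispatch: parallel linear layers attach a `weight_loader` callable to their parameters, and an `is_sharded_weight` flag tells the V1 loader to copy the incoming tensor verbatim instead of slicing out this rank's shard. The sketch below is a self-contained illustration of that pattern; `FakeColumnParallelLinear` and its shapes are invented for illustration and are not vLLM's API.

```python
# Self-contained sketch of the weight_loader indirection the extension.py patch
# relies on. FakeColumnParallelLinear is a made-up stand-in, not vLLM's class.
import torch

class FakeColumnParallelLinear(torch.nn.Module):
    """Pretend TP rank 0 of 2; this rank owns the first half of the rows."""

    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.zeros(4, 8))  # this rank's shard

    def weight_loader(self, param, loaded_weight):
        # V1-style loader: skip slicing when the weight arrives pre-sharded.
        if getattr(param, "is_sharded_weight", False):
            param.data.copy_(loaded_weight)      # already this rank's shard
        else:
            param.data.copy_(loaded_weight[:4])  # slice rank 0's rows from full weight

module = FakeColumnParallelLinear()

# The patch's workaround: pin the module's V1 loader onto the parameter
# (bypassing any V2 dispatch) and mark incoming weights as pre-sharded.
module.weight.weight_loader = module.weight_loader
module.weight.is_sharded_weight = True

shard = torch.ones(4, 8)                           # e.g. a shard sent from JAX
module.weight.weight_loader(module.weight, shard)  # copied verbatim, no slicing
assert torch.equal(module.weight.data, shard)
```

Removing `UnquantizedLinearMethod` from `WEIGHT_LOADER_V2_SUPPORTED` attacks the same problem from the other side: it keeps vLLM from routing these modules through the V2 path in the first place, so the V1 loader pinned above is the one that actually runs.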