From 000099784e7f0c9423bc7dd4ebdc388fa0cca4ee Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Tue, 25 Nov 2025 08:15:41 +0000 Subject: [PATCH 1/7] patch vLLM weight loader --- jax-inference-offloading/dockerfile/oss.dockerfile | 3 ++- .../jax_inference_offloading/vllm/extension.py | 13 ++++++++++++- jax-inference-offloading/setup.py | 4 ++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/jax-inference-offloading/dockerfile/oss.dockerfile b/jax-inference-offloading/dockerfile/oss.dockerfile index 2555e3798..dbbac60b0 100644 --- a/jax-inference-offloading/dockerfile/oss.dockerfile +++ b/jax-inference-offloading/dockerfile/oss.dockerfile @@ -91,11 +91,12 @@ FROM mealkit AS final # Finalize installation RUN <<"EOF" bash -ex -o pipefail export PIP_INDEX_URL=https://download.pytorch.org/whl/cu129 -export PIP_EXTRA_INDEX_URL=https://pypi.org/simple +export PIP_EXTRA_INDEX_URL="https://flashinfer.ai/whl/cu129 https://pypi.org/simple" pushd /opt/pip-tools.d pip-compile -o requirements.txt $(ls requirements*.in) --constraint overrides.in # remove cuda wheels from install list since the container already has them sed -i 's/^nvidia-/# nvidia-/g' requirements.txt +sed -i 's/# nvidia-nvshmem/nvidia-nvshmem/g' requirements.txt pip install --no-deps --src /opt -r requirements.txt # make pip happy about the missing torch dependencies pip-mark-installed nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 \ diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py index 59252ca02..4a4a6ae8e 100644 --- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py +++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py @@ -30,6 +30,7 @@ MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear, + WEIGHT_LOADER_V2_SUPPORTED, ) from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding @@ -51,14 +52,24 @@ def device_info(self): ) def set_sharding(self): - for _, module in self.model_runner.model.named_modules(): + + try: + if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: + WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") + logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.") + except Exception as e: + logger.warning(f"Unable to adjust WEIGHT_LOADER_V2_SUPPORTED: {e}") + + for name, module in self.model_runner.model.named_modules(): if type(module) in [ RowParallelLinear, ColumnParallelLinear, MergedColumnParallelLinear, QKVParallelLinear, ]: + logger.debug(f"Setting sharding for module: {name} of type {type(module)}") module.weight.is_sharded_weight = True + module.weight.weight_loader = module.weight_loader def get_tp_sharding_specs(self): sharding_specs = {} diff --git a/jax-inference-offloading/setup.py b/jax-inference-offloading/setup.py index f60fe32e9..3963e4c71 100644 --- a/jax-inference-offloading/setup.py +++ b/jax-inference-offloading/setup.py @@ -63,10 +63,10 @@ def run(self): 'grpcio==1.76.*', 'protobuf==6.33.*', 'huggingface-hub', - 'jax==0.8.0', + 'jax==0.8.1', 'jaxtyping', 'kagglehub', - 'vllm[flashinfer]==0.10.2', + 'vllm==0.11.2', ], cmdclass={ 'build_protos': BuildPackageProtos, From 82fe0ce6812df8c01b1ac2739984f8995c5e5178 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Tue, 25 Nov 2025 08:21:33 +0000 Subject: [PATCH 2/7] bump tunix version --- jax-inference-offloading/examples/requirements.in | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax-inference-offloading/examples/requirements.in b/jax-inference-offloading/examples/requirements.in index 06b18129e..2865834d3 100644 --- a/jax-inference-offloading/examples/requirements.in +++ b/jax-inference-offloading/examples/requirements.in @@ -1,7 +1,7 @@ -google-tunix==0.1.3 +google-tunix==0.1.6 datasets tensorflow-datasets tensorflow-cpu; platform_machine == "x86_64" -tensorflow-aarch64; platform_machine == "aarch64" +tensorflow; platform_machine == "aarch64" grain ray From 5591ce449ec20da26b339ae83609f3ec16cb6c45 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Wed, 26 Nov 2025 05:10:30 +0000 Subject: [PATCH 3/7] address PR comments --- .../examples/example-transfer.sh | 1 - jax-inference-offloading/examples/rollout.py | 2 +- .../jax_inference_offloading/vllm/extension.py | 17 ++++++++++------- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/jax-inference-offloading/examples/example-transfer.sh b/jax-inference-offloading/examples/example-transfer.sh index 911cbce8f..d8971cb3f 100755 --- a/jax-inference-offloading/examples/example-transfer.sh +++ b/jax-inference-offloading/examples/example-transfer.sh @@ -245,7 +245,6 @@ python "${DIR}/../jax_inference_offloading/controller/gateway.py" 2>&1 | tee ${O PIDS+=($!) CUDA_VISIBLE_DEVICES=$(IFS=','; echo "${VLLM_GPU_ARRAY[*]}") \ -MODEL_NAME=${MODEL_PATH:-$MODEL_NAME} \ python "${DIR}/rollout.py" 2>&1 | tee ${OUTPUT_DIR}/rollout.log & PIDS+=($!) diff --git a/jax-inference-offloading/examples/rollout.py b/jax-inference-offloading/examples/rollout.py index 61e91df50..e1276071a 100644 --- a/jax-inference-offloading/examples/rollout.py +++ b/jax-inference-offloading/examples/rollout.py @@ -39,7 +39,7 @@ def main(): load_format = os.environ.get("VLLM_LOAD_FORMAT", "dummy") model_name = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") model_path = os.environ.get("MODEL_PATH", None) - model = model_path or model_name + model = model_name or model_path logging.basicConfig(level=logging.INFO) diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py index 4a4a6ae8e..ab6fe07d3 100644 --- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py +++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py @@ -52,13 +52,14 @@ def device_info(self): ) def set_sharding(self): + # The vLLM V2 weight loader does not support loading pre-sharded weights + # for the parallel linear modules. + # Therefore, we need to force these modules to use the V1 weight loader - try: - if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: - WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") - logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.") - except Exception as e: - logger.warning(f"Unable to adjust WEIGHT_LOADER_V2_SUPPORTED: {e}") + # Prevent unquantized linear modules from using V2 weight loader + if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: + WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod") + logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.") for name, module in self.model_runner.model.named_modules(): if type(module) in [ @@ -68,8 +69,10 @@ def set_sharding(self): QKVParallelLinear, ]: logger.debug(f"Setting sharding for module: {name} of type {type(module)}") - module.weight.is_sharded_weight = True + # force to use the V1 weight_loader module.weight.weight_loader = module.weight_loader + # instruct V1 loader to treat the incoming weight as pre-sharded + module.weight.is_sharded_weight = True def get_tp_sharding_specs(self): sharding_specs = {} From d45fa3a68a32b4cfaae9cb6299aa2474534d4553 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Wed, 26 Nov 2025 05:13:20 +0000 Subject: [PATCH 4/7] address PR comments --- .../jax_inference_offloading/vllm/extension.py | 1 + 1 file changed, 1 insertion(+) diff --git a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py index ab6fe07d3..1fb71727d 100644 --- a/jax-inference-offloading/jax_inference_offloading/vllm/extension.py +++ b/jax-inference-offloading/jax_inference_offloading/vllm/extension.py @@ -55,6 +55,7 @@ def set_sharding(self): # The vLLM V2 weight loader does not support loading pre-sharded weights # for the parallel linear modules. # Therefore, we need to force these modules to use the V1 weight loader + # Once V2 weight loader supports pre-sharded weights, we can remove this workaround. # Prevent unquantized linear modules from using V2 weight loader if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED: From a081257236878dadc70682510e8be9299131b386 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Fri, 28 Nov 2025 08:21:44 +0000 Subject: [PATCH 5/7] revert change related to a vllm tokenizer load failure that is no longer reproducible --- jax-inference-offloading/examples/example-transfer.sh | 1 + jax-inference-offloading/examples/rollout.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/jax-inference-offloading/examples/example-transfer.sh b/jax-inference-offloading/examples/example-transfer.sh index d8971cb3f..911cbce8f 100755 --- a/jax-inference-offloading/examples/example-transfer.sh +++ b/jax-inference-offloading/examples/example-transfer.sh @@ -245,6 +245,7 @@ python "${DIR}/../jax_inference_offloading/controller/gateway.py" 2>&1 | tee ${O PIDS+=($!) CUDA_VISIBLE_DEVICES=$(IFS=','; echo "${VLLM_GPU_ARRAY[*]}") \ +MODEL_NAME=${MODEL_PATH:-$MODEL_NAME} \ python "${DIR}/rollout.py" 2>&1 | tee ${OUTPUT_DIR}/rollout.log & PIDS+=($!) diff --git a/jax-inference-offloading/examples/rollout.py b/jax-inference-offloading/examples/rollout.py index e1276071a..61e91df50 100644 --- a/jax-inference-offloading/examples/rollout.py +++ b/jax-inference-offloading/examples/rollout.py @@ -39,7 +39,7 @@ def main(): load_format = os.environ.get("VLLM_LOAD_FORMAT", "dummy") model_name = os.environ.get("MODEL_NAME", "meta-llama/Llama-3.1-8B-Instruct") model_path = os.environ.get("MODEL_PATH", None) - model = model_name or model_path + model = model_path or model_name logging.basicConfig(level=logging.INFO) From 8b3440f06245051d6402f96f5d12b58277444bfa Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Fri, 28 Nov 2025 08:28:37 +0000 Subject: [PATCH 6/7] remove tunix version pin --- jax-inference-offloading/examples/requirements.in | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax-inference-offloading/examples/requirements.in b/jax-inference-offloading/examples/requirements.in index 2865834d3..9dba5c073 100644 --- a/jax-inference-offloading/examples/requirements.in +++ b/jax-inference-offloading/examples/requirements.in @@ -1,7 +1,7 @@ -google-tunix==0.1.6 +google-tunix datasets tensorflow-datasets tensorflow-cpu; platform_machine == "x86_64" -tensorflow; platform_machine == "aarch64" +tensorflow-aarch64; platform_machine == "aarch64" grain ray From 43729847059c0f7fdc8bc2f81faff49a4f3356f8 Mon Sep 17 00:00:00 2001 From: Yu-Hang Tang Date: Fri, 28 Nov 2025 08:30:42 +0000 Subject: [PATCH 7/7] revert JAX version bump --- jax-inference-offloading/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax-inference-offloading/setup.py b/jax-inference-offloading/setup.py index 3963e4c71..9bd9a89b8 100644 --- a/jax-inference-offloading/setup.py +++ b/jax-inference-offloading/setup.py @@ -63,7 +63,7 @@ def run(self): 'grpcio==1.76.*', 'protobuf==6.33.*', 'huggingface-hub', - 'jax==0.8.1', + 'jax==0.8.0', 'jaxtyping', 'kagglehub', 'vllm==0.11.2',