Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jax-inference-offloading/dockerfile/oss.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,12 @@ FROM mealkit AS final
# Finalize installation
RUN <<"EOF" bash -ex -o pipefail
export PIP_INDEX_URL=https://download.pytorch.org/whl/cu129
export PIP_EXTRA_INDEX_URL=https://pypi.org/simple
export PIP_EXTRA_INDEX_URL="https://flashinfer.ai/whl/cu129 https://pypi.org/simple"
pushd /opt/pip-tools.d
pip-compile -o requirements.txt $(ls requirements*.in) --constraint overrides.in
# remove cuda wheels from install list since the container already has them
sed -i 's/^nvidia-/# nvidia-/g' requirements.txt
sed -i 's/# nvidia-nvshmem/nvidia-nvshmem/g' requirements.txt
pip install --no-deps --src /opt -r requirements.txt
# make pip happy about the missing torch dependencies
pip-mark-installed nvidia-cublas-cu12 nvidia-cuda-cupti-cu12 \
Expand Down
2 changes: 1 addition & 1 deletion jax-inference-offloading/examples/requirements.in
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
google-tunix==0.1.3
google-tunix
datasets
tensorflow-datasets
tensorflow-cpu; platform_machine == "x86_64"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
WEIGHT_LOADER_V2_SUPPORTED,
)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding

Expand All @@ -51,13 +52,27 @@ def device_info(self):
)

def set_sharding(self):
for _, module in self.model_runner.model.named_modules():
# The vLLM V2 weight loader does not support loading pre-sharded weights
# for the parallel linear modules.
# Therefore, we need to force these modules to use the V1 weight loader
# Once V2 weight loader supports pre-sharded weights, we can remove this workaround.

# Prevent unquantized linear modules from using V2 weight loader
if "UnquantizedLinearMethod" in WEIGHT_LOADER_V2_SUPPORTED:
WEIGHT_LOADER_V2_SUPPORTED.remove("UnquantizedLinearMethod")
logger.warning("Removed UnquantizedLinearMethod from WEIGHT_LOADER_V2_SUPPORTED.")

for name, module in self.model_runner.model.named_modules():
if type(module) in [
RowParallelLinear,
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
]:
logger.debug(f"Setting sharding for module: {name} of type {type(module)}")
# force to use the V1 weight_loader
module.weight.weight_loader = module.weight_loader
# instruct V1 loader to treat the incoming weight as pre-sharded
module.weight.is_sharded_weight = True

def get_tp_sharding_specs(self):
Expand Down
2 changes: 1 addition & 1 deletion jax-inference-offloading/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def run(self):
'jax==0.8.0',
'jaxtyping',
'kagglehub',
'vllm[flashinfer]==0.10.2',
'vllm==0.11.2',
],
cmdclass={
'build_protos': BuildPackageProtos,
Expand Down
Loading