JAX inference offloading bridge #1775
Conversation
Pull Request Overview
This PR introduces a JAX-vLLM rollout offloading bridge that enables efficient coupling between JAX training and vLLM inference for reinforcement learning post-training workloads. The bridge offloads rollout generation to vLLM while keeping training in JAX, using NCCL for direct GPU-to-GPU weight transfers.
- Implements a lightweight RPC gateway for control plane coordination between trainer and rollout engine
- Provides NCCL-based data plane for fast GPU-to-GPU weight streaming with tensor resharding
- Supports multiple transfer modes (fused, unfused, grouped) and flexible parallelism configurations (FSDP/TP), as sketched below
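To make the trainer/rollout split described above more concrete, here is a minimal, purely illustrative sketch of how such a bridge might be driven from the JAX side. The `Bridge` class, its methods, and the gateway address below are stand-ins invented for this sketch; they are not the API added by this PR.

```python
# Conceptual sketch only: `Bridge` is a stand-in for whatever interface the PR's
# jax_inference_offloading.jax module actually exposes; all names here are invented.
import jax
import jax.numpy as jnp


class Bridge:
    """Stand-in for the offloading bridge: RPC control plane + NCCL data plane."""

    def __init__(self, gateway_address: str):
        # Control plane: an RPC gateway would coordinate trainer and rollout engine.
        self.gateway_address = gateway_address

    def push_weights(self, params):
        # Data plane: the real bridge would reshard (e.g. FSDP -> TP) and stream
        # weights to the vLLM workers over NCCL; here we just count parameters.
        n = sum(p.size for p in jax.tree_util.tree_leaves(params))
        print(f"would stream {n} parameters to vLLM via {self.gateway_address}")

    def generate(self, prompts, max_tokens=64):
        # The real bridge would return vLLM rollouts; here we simply echo the prompts.
        return [p + " ..." for p in prompts]


params = {"layer0": {"kernel": jnp.zeros((4, 4))}}
bridge = Bridge(gateway_address="localhost:50051")  # hypothetical gateway address

for step in range(2):
    # Placeholder for the real JAX training update.
    params = jax.tree_util.tree_map(lambda p: p - 1e-4, params)
    bridge.push_weights(params)            # sync updated weights to the rollout engine
    rollouts = bridge.generate(["Hello"])  # request rollouts for the next RL step
```

The split mirrors the overview above: coordination stays on a lightweight RPC channel, while the weight payloads would move GPU-to-GPU over NCCL.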
Reviewed Changes
Copilot reviewed 59 out of 59 changed files in this pull request and generated 35 comments.
| File | Description |
|---|---|
| setup.py | Package setup with protobuf build hooks and dependencies |
| pyproject.toml | Build system configuration with ruff linting rules |
| pep517_backend.py | Custom PEP 517 build backend for protobuf compilation |
| jax_inference_offloading/vllm/ | vLLM wrapper and worker extension for weight updates |
| jax_inference_offloading/transport/ | NCCL transport implementations (star topology, tensor/model transports) |
| jax_inference_offloading/controller/ | Gateway server and client implementations for trainer/rollout coordination |
| jax_inference_offloading/models/ | Model parameter mappings for Llama3 and Gemma families (see the sketch after this table) |
| jax_inference_offloading/jax/ | Offloading bridge API for JAX integration |
| jax_inference_offloading/tunix/ | Tunix-specific rollout and model loading utilities |
| examples/ | Example scripts for single-node and multi-node deployments |
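For a rough sense of what the parameter-mapping modules under `jax_inference_offloading/models/` might contain, here is an illustrative sketch. The key names, checkpoint layout, and helper below are assumptions made up for this example, not taken from the PR.

```python
# Illustrative only: maps a hypothetical JAX/Flax parameter path to the
# HuggingFace-style weight name that vLLM expects for a Llama3-family model.
# The actual mappings in this PR may use different paths, names, and structure.
LLAMA3_PARAM_MAP = {
    "params/embedder/embedding": "model.embed_tokens.weight",
    "params/layers_0/attn/q_proj/kernel": "model.layers.0.self_attn.q_proj.weight",
    "params/layers_0/attn/k_proj/kernel": "model.layers.0.self_attn.k_proj.weight",
    "params/layers_0/mlp/gate_proj/kernel": "model.layers.0.mlp.gate_proj.weight",
    "params/final_norm/scale": "model.norm.weight",
}


def to_vllm_name(jax_path: str) -> str:
    """Return the vLLM-side weight name for a JAX parameter path (illustrative)."""
    return LLAMA3_PARAM_MAP[jax_path]


print(to_vllm_name("params/final_norm/scale"))  # -> model.norm.weight
```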
Comments suppressed due to low confidence (1)
jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py:139
- This assignment to 'shutdown' is unnecessary as it is redefined before this value is used (pattern illustrated below).
def shutdown(self):
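For context, the pattern this suppressed comment describes looks roughly like the following. This is a reconstructed illustration, not the actual `rollout_client.py` code: a value bound to the name `shutdown` is shadowed by the later method definition, so the first binding is dead.

```python
# Reconstructed illustration of the flagged pattern (not the actual rollout_client.py):
# the class-level assignment to `shutdown` is never used because the method
# definition below rebinds the same name in the class namespace.
class RolloutClient:
    shutdown = None  # dead binding: rebound by the method definition below

    def shutdown(self):
        """Close the connection to the rollout engine (illustrative)."""
        print("shutting down")
```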
Review comments (now outdated) were resolved on the following files:
- jax-inference-offloading/jax_inference_offloading/controller/rollout_client.py
- jax-inference-offloading/jax_inference_offloading/vllm/extension.py
- jax-inference-offloading/jax_inference_offloading/transport/model/nccl_fused.py (3 threads)
- jax-inference-offloading/jax_inference_offloading/tunix/load_model.py
@jreiffers PTAL
Commit: …ollout_client.py (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
The new CI workflow for the jax-inference-offloading subfolder is currently failing because the base container image does not yet include the changes from this PR (git-clone.sh). Given this cyclic dependency, I suggest we treat the workflow as a pilot and merge as-is, as long as the main CI passes, then iterate to fix the workflow and address any issues once the updated base container is available.
@yhtang JIO is not part of the main CI? Also, I can't understand why the CI doesn't pick up your changes to git-clone.sh.
@yhtang Other question:
@yhtang Could you clarify what you would like the CI here to achieve? Is it to build and publish an image with JAX inference offloading based on nightly JAX (i.e. similar to what we do with maxtext)? Is there also an intention to run the examples you've added in the CI with this image, or some other way to test it?
I think the issue is that this workflow is still using the vanilla CUDA DL base image rather than the JAX-Toolbox base image built in our workflow. As a result, it doesn't see the updated git-clone.sh even though it's included in this PR. Nevertheless, I've reverted git-clone.sh and will add it back when we merge the offloading CI workflow with the main workflow. My plan is to introduce two Dockerfiles: one for a pure OSS installation and another based on the JAX-Toolbox base image and JAX builds. I'd handle the second Dockerfile and the corresponding CI wiring in a follow-up PR to keep the scope of this change focused.
The standalone JIO CI workflow is now passing, and the remaining main CI failures (e.g. https://github.com/NVIDIA/JAX-Toolbox/actions/runs/19385835327/job/55473082584?pr=1775#step:4:2129) match those on the main branch (e.g. https://github.com/NVIDIA/JAX-Toolbox/actions/runs/19386305470/job/55474369243#step:4:2407). Per our earlier agreement that this PR is scoped to changes within the jax-inference-offloading folder, I'll go ahead and merge this PR and keep subsequent updates similarly contained until we work together to integrate the JIO CI into the main CI.